zipenhancer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zipenhancer/__init__.py +32 -0
- zipenhancer/codec.py +524 -0
- zipenhancer/denoise.py +144 -0
- zipenhancer/models/__init__.py +4 -0
- zipenhancer/models/layers/__init__.py +4 -0
- zipenhancer/models/layers/generator.py +215 -0
- zipenhancer/models/layers/scaling.py +412 -0
- zipenhancer/models/layers/zipenhancer_layer.py +500 -0
- zipenhancer/models/layers/zipformer.py +1035 -0
- zipenhancer/models/zipenhancer.py +225 -0
- zipenhancer/standalone.py +164 -0
- zipenhancer-0.1.0.dist-info/METADATA +391 -0
- zipenhancer-0.1.0.dist-info/RECORD +16 -0
- zipenhancer-0.1.0.dist-info/WHEEL +5 -0
- zipenhancer-0.1.0.dist-info/licenses/LICENSE +21 -0
- zipenhancer-0.1.0.dist-info/top_level.txt +1 -0
zipenhancer/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
zipenhancer — 语音降噪核心包
|
|
3
|
+
|
|
4
|
+
用法:
|
|
5
|
+
from zipenhancer import denoise, write
|
|
6
|
+
|
|
7
|
+
audio, sr = librosa.load("noisy.wav", sr=16000)
|
|
8
|
+
denoised, proc_time, duration = denoise(audio, sr)
|
|
9
|
+
write("output.flac", denoised, sr, fmt="flac")
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from zipenhancer.codec import write, FORMATS, get_supported_formats, WriteResult
|
|
13
|
+
from zipenhancer.codec import FormatConfig, CodecError, FormatNotSupported
|
|
14
|
+
from zipenhancer.denoise import denoise, load_model, ensure_model, normalize_audio
|
|
15
|
+
from zipenhancer.denoise import MODEL_ZIPENHANCER, MODEL_FRCRN, MODEL_MOSSFORMER2
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"denoise",
|
|
19
|
+
"write",
|
|
20
|
+
"load_model",
|
|
21
|
+
"ensure_model",
|
|
22
|
+
"normalize_audio",
|
|
23
|
+
"FORMATS",
|
|
24
|
+
"get_supported_formats",
|
|
25
|
+
"WriteResult",
|
|
26
|
+
"FormatConfig",
|
|
27
|
+
"CodecError",
|
|
28
|
+
"FormatNotSupported",
|
|
29
|
+
"MODEL_ZIPENHANCER",
|
|
30
|
+
"MODEL_FRCRN",
|
|
31
|
+
"MODEL_MOSSFORMER2",
|
|
32
|
+
]
|
zipenhancer/codec.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
import subprocess
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import soundfile as sf
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class FormatConfig:
    """Immutable description of the capabilities of one output format."""

    # --- required fields -------------------------------------------------
    # Unique identifier: wav / flac / mp3 / ogg
    name: str
    # Human-readable description shown by the API
    description: str
    # File extension, including the leading dot
    ext: str
    # Valid subtypes; an empty tuple means the codec name decides
    subtypes: tuple
    # Subtype used when the caller does not pick one
    default_subtype: str
    # Whether encoding this format depends on ffmpeg
    needs_ffmpeg: bool

    # --- optional capabilities -------------------------------------------
    supports_compression: bool = False
    supports_bitrate: bool = False
    compression_range: tuple = ()
    bitrate_range: tuple = ()
    default_bitrate: Optional[str] = None
    default_compression: Optional[int] = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Registry of supported output formats, keyed by each config's ``name``.
# Insertion order (wav, flac, mp3, ogg) is the order reported to callers.
FORMATS = {
    spec.name: spec
    for spec in (
        FormatConfig(
            name="wav",
            description="WAV(无损,兼容性最佳)",
            ext=".wav",
            subtypes=("PCM_16", "PCM_24", "PCM_32", "FLOAT"),
            default_subtype="PCM_16",
            needs_ffmpeg=False,
        ),
        FormatConfig(
            name="flac",
            description="FLAC(无损压缩,文件较小)",
            ext=".flac",
            subtypes=("PCM_16", "PCM_24"),
            default_subtype="PCM_16",
            needs_ffmpeg=False,
            supports_compression=True,
            compression_range=(0, 8),
            default_compression=5,
        ),
        FormatConfig(
            name="mp3",
            description="MP3(有损压缩,广泛兼容)",
            ext=".mp3",
            subtypes=(),
            default_subtype="libmp3lame",
            needs_ffmpeg=True,
            supports_bitrate=True,
            bitrate_range=(32000, 320000),
            default_bitrate="192k",
        ),
        FormatConfig(
            name="ogg",
            description="OGG Opus/Vorbis(有损压缩,开源优选)",
            ext=".ogg",
            subtypes=(),
            default_subtype="opus",
            needs_ffmpeg=True,
            supports_bitrate=True,
            bitrate_range=(6000, 510000),
            default_bitrate="192k",
        ),
    )
}
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class CodecError(Exception):
    """Base class for every encode/decode error raised by this module."""

    def __init__(self, message: str, hint: str = ""):
        """Build the error text from a message and an optional fix hint.

        Args:
            message (str): Description of what went wrong.
            hint (str, optional): Suggestion for fixing the problem; when
                non-empty it is appended to the message in parentheses.
                Defaults to "".
        """
        self.hint = hint
        if hint:
            text = f"{message}({hint})"
        else:
            text = message
        super().__init__(text)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class FormatNotSupported(CodecError):
    """Raised when the requested output format is not in ``FORMATS``."""

    def __init__(self, fmt: str):
        # The hint lists every registered format name.
        available = ", ".join(FORMATS)
        message = f"不支持的输出格式: {fmt}"
        hint = f"可选: {available}"
        super().__init__(message, hint)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class SubtypeNotSupported(CodecError):
    """Raised when a format does not support the requested subtype."""

    def __init__(self, fmt: str, subtype: str, valid: tuple):
        # ``valid`` holds the subtypes the format accepts; they go in the hint.
        accepted = ", ".join(valid)
        message = f"格式 {fmt} 不支持编码 {subtype}"
        hint = f"{fmt} 支持的编码: {accepted}"
        super().__init__(message, hint)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class FfmpegNotFound(CodecError):
    """Raised when the ffmpeg executable cannot be run."""

    def __init__(self):
        message = "系统未找到 ffmpeg"
        hint = "请安装 ffmpeg 并将其加入 PATH(https://ffmpeg.org/download.html)"
        super().__init__(message, hint)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class DiskSpaceError(CodecError):
    """Raised when there is not enough free disk space for the output."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class EncodeError(CodecError):
    """Raised when encoding the audio fails."""
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass
class WriteResult:
    """Outcome of a ``write`` call: where the file went and how it was encoded."""

    # Output file path, as passed by the caller
    path: str
    # Format name that was requested
    format: str
    # Subtype after defaults/fallbacks were applied
    subtype: str
    # Bitrate used, or None when the format has no bitrate setting
    bitrate: Optional[str] = None
    # Compression level used, or None when the format has none
    compression: Optional[int] = None
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    fmt: str = "wav",
    subtype: Optional[str] = None,
    bitrate: Optional[str] = None,
    compression: Optional[int] = None,
    atomic: bool = True,
) -> WriteResult:
    """Encode audio to ``path`` in the requested format.

    The output directory is created if missing, the encoding parameters are
    validated (falling back to the format defaults), free disk space is
    checked against a size estimate, and the file is then written.

    Args:
        path (str): Output file path.
        data (np.ndarray): Audio samples.
        sample_rate (int): Sample rate of ``data``.
        fmt (str, optional): Output format key in ``FORMATS``. Defaults to "wav".
        subtype (Optional[str], optional): Encoding subtype. Defaults to None.
        bitrate (Optional[str], optional): Bitrate such as "192k". Defaults to None.
        compression (Optional[int], optional): Compression level. Defaults to None.
        atomic (bool, optional): Write to a temporary file and rename it over
            ``path`` instead of writing in place. Defaults to True.

    Raises:
        FormatNotSupported: If ``fmt`` is not a key of ``FORMATS``.

    Returns:
        WriteResult: The path and the parameters that were actually used.
    """
    if fmt not in FORMATS:
        raise FormatNotSupported(fmt)
    cfg = FORMATS[fmt]

    # Make sure the destination directory exists.
    out_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(out_dir, exist_ok=True)

    # 1. Validate parameters and fill in defaults.
    subtype, bitrate, compression = _resolve_params(cfg, subtype, bitrate, compression)

    # 2. Estimate the output size and check free disk space.
    size_guess = _estimate_size(data, sample_rate, cfg, subtype, bitrate, compression)
    _check_disk_space(path, size_guess)

    # 3. Write, atomically or in place.
    writer = _atomic_write if atomic else _do_write
    writer(path, data, sample_rate, cfg, subtype, bitrate, compression)

    return WriteResult(
        path=path,
        format=fmt,
        subtype=subtype,
        bitrate=bitrate,
        compression=compression,
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _resolve_params(
    cfg: FormatConfig,
    subtype: Optional[str],
    bitrate: Optional[str],
    compression: Optional[int],
) -> tuple:
    """Validate the encoding parameters and fill in format defaults.

    Invalid values are never fatal: each one is replaced by the format's
    default and a warning is logged.

    Args:
        cfg: Capabilities of the target format.
        subtype: Requested subtype, or None for the format default.
        bitrate: Requested bitrate (e.g. "192k"), or None for the default.
            Ignored (returned as None) when the format has no bitrate setting.
        compression: Requested compression level, or None for the default.
            Ignored (returned as None) when the format has no compression.

    Returns:
        tuple: ``(subtype, bitrate, compression)`` after validation. A
        bitrate that passed validation is always returned as a ``str`` so it
        can be placed directly on the ffmpeg command line.
    """
    # subtype: only checked when the format declares an explicit subtype list
    if subtype is None:
        subtype = cfg.default_subtype
    elif cfg.subtypes and subtype not in cfg.subtypes:
        logger.warning(
            "格式 %s 不支持编码 %s,降级为 %s",
            cfg.name, subtype, cfg.default_subtype,
        )
        subtype = cfg.default_subtype

    # bitrate
    if not cfg.supports_bitrate:
        bitrate = None
    elif bitrate is None:
        bitrate = cfg.default_bitrate
    else:
        try:
            bps = _parse_bitrate(bitrate)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers inputs such as "infk" whose float value
            # cannot be converted to an integer.
            logger.warning("比特率格式无效 %s,使用默认值 %s", bitrate, cfg.default_bitrate)
            bitrate = cfg.default_bitrate
        else:
            low, high = cfg.bitrate_range[0], cfg.bitrate_range[1]
            if bps < low or bps > high:
                logger.warning(
                    "比特率 %s 超出 %s 范围 (%d-%d bps),使用默认值 %s",
                    bitrate, cfg.name, low, high,
                    cfg.default_bitrate,
                )
                bitrate = cfg.default_bitrate
            elif isinstance(bitrate, str):
                # Drop surrounding whitespace that the parser tolerated.
                bitrate = bitrate.strip()
            else:
                # A numeric bitrate (e.g. 192000) parsed fine but would make
                # the subprocess argument list invalid; pass it on as text.
                bitrate = str(bps)

    # compression
    if not cfg.supports_compression:
        compression = None
    elif compression is None:
        compression = cfg.default_compression
    elif compression < cfg.compression_range[0] or compression > cfg.compression_range[1]:
        logger.warning(
            "压缩级别 %d 超出 %s 范围 (%d-%d),使用默认值 %d",
            compression, cfg.name, cfg.compression_range[0], cfg.compression_range[1],
            cfg.default_compression,
        )
        compression = cfg.default_compression

    return subtype, bitrate, compression
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _parse_bitrate(bitrate: str) -> int:
|
|
251
|
+
"""将比特率字符串转为 bps,如 '192k' → 192000"""
|
|
252
|
+
s = str(bitrate).strip().lower()
|
|
253
|
+
if s.endswith("k"):
|
|
254
|
+
return int(float(s[:-1]) * 1000)
|
|
255
|
+
elif s.endswith("m"):
|
|
256
|
+
return int(float(s[:-1]) * 1_000_000)
|
|
257
|
+
else:
|
|
258
|
+
return int(s)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _estimate_size(
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
) -> int:
    """Estimate the size of the encoded output file in bytes.

    ``data`` is read as [samples] or [samples, channels]. WAV and FLAC are
    sized from the PCM payload, MP3 and OGG from bitrate times duration, and
    anything else from a 4-bytes-per-sample worst case.
    """
    n_samples = data.shape[0]
    n_channels = 1 if data.ndim == 1 else max(data.shape[1], 1)
    duration = n_samples / sample_rate

    fmt_name = cfg.name
    if fmt_name == "wav":
        # Raw PCM plus a small allowance for the header.
        return int(n_samples * n_channels * _subtype_bytes(subtype) * 1.02)
    if fmt_name == "flac":
        # FLAC usually compresses to 50-70% of the raw PCM size.
        return int(n_samples * n_channels * _subtype_bytes(subtype) * 0.65)
    if fmt_name in ("mp3", "ogg") and bitrate:
        bits_per_second = _parse_bitrate(bitrate)
        return int(bits_per_second / 8 * duration * 1.05)
    # Unknown format or no bitrate: assume 4 bytes per sample plus 10%.
    return int(n_samples * n_channels * 4 * 1.1)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _subtype_bytes(subtype: str) -> int:
|
|
297
|
+
"""subtype → 每样本字节数"""
|
|
298
|
+
return {
|
|
299
|
+
"PCM_S8": 1, "PCM_U8": 1,
|
|
300
|
+
"PCM_16": 2,
|
|
301
|
+
"PCM_24": 3,
|
|
302
|
+
"PCM_32": 4, "FLOAT": 4,
|
|
303
|
+
"DOUBLE": 8,
|
|
304
|
+
}.get(subtype, 4)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _do_write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
):
    """Perform the actual write, dispatching on the format.

    WAV and FLAC go through soundfile; formats flagged ``needs_ffmpeg`` go
    through ffmpeg.

    Raises:
        EncodeError: If the format matches neither backend.
    """
    if cfg.name in ("wav", "flac"):
        _write_sf(path, data, sample_rate, cfg, subtype, compression)
        return

    if not cfg.needs_ffmpeg:
        raise EncodeError(f"未知格式: {cfg.name}")

    _write_ffmpeg(path, data, sample_rate, cfg, subtype, bitrate)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _write_sf(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    compression: Optional[int],
):
    """Write WAV or FLAC through soundfile.

    Args:
        path: Output file path.
        data: Samples as [samples] or [samples, channels], which is already
            soundfile's native layout, so it is passed through unchanged.
        sample_rate: Sample rate of ``data``.
        cfg: Target format; ``cfg.name`` selects the soundfile format.
        subtype: soundfile subtype, e.g. "PCM_16".
        compression: FLAC compression level on the 0-8 scale, or None.

    Raises:
        EncodeError: If soundfile fails for any reason.
    """
    sf_format = cfg.name.upper()
    if cfg.name == "wav":
        # Switch to RF64 for very large WAV files. The decision uses the
        # encoded payload size (samples x bytes per sample of the target
        # subtype), not the in-memory array size: float64 samples written as
        # PCM_16 are four times smaller on disk, while int16 samples written
        # as FLOAT are twice as large.
        encoded_bytes = data.size * _subtype_bytes(subtype)
        if encoded_bytes > 3.5 * 1024 ** 3:
            sf_format = "RF64"
            logger.info("文件超过 4GB,自动使用 RF64 格式")

    sf_kwargs = {"format": sf_format, "subtype": subtype}
    if compression is not None and cfg.name == "flac":
        # Per the original author's note, soundfile 0.13.x takes
        # compression_level in the range 0.0-1.0 for FLAC, so the 0-8 level
        # is rescaled.
        sf_kwargs["compression_level"] = compression / 8.0

    try:
        sf.write(path, data, sample_rate, **sf_kwargs)
    except Exception as e:
        raise EncodeError(
            f"soundfile 编码失败: {e}",
            hint="检查磁盘空间和文件权限",
        ) from e
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# 每个格式的首选编码器列表(按优先级)
|
|
361
|
+
_FFMPEG_ENCODERS = {
|
|
362
|
+
"mp3": ["libmp3lame", "mp3_mf"],
|
|
363
|
+
"ogg": ["libvorbis", "opus", "vorbis"],
|
|
364
|
+
}
|
|
365
|
+
_encoder_cache = {}
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def _detect_encoder(format_name: str) -> str:
    """Pick an available ffmpeg encoder for a format, caching the result.

    Candidates come from ``_FFMPEG_ENCODERS`` in priority order; the first
    one listed by ``ffmpeg -encoders`` wins.

    Args:
        format_name: Format key such as "mp3" or "ogg".

    Returns:
        str: The chosen encoder name. Returns "" when the format has no
        candidates, and the first candidate when ffmpeg could not be queried
        or lists none of them.
    """
    if format_name in _encoder_cache:
        return _encoder_cache[format_name]

    candidates = _FFMPEG_ENCODERS.get(format_name, [])
    if not candidates:
        _encoder_cache[format_name] = ""
        return ""

    try:
        proc = subprocess.run(
            ["ffmpeg", "-encoders"],
            capture_output=True, text=True, timeout=10,
        )
    except Exception:
        # Best effort: if ffmpeg cannot be queried, fall back to the first
        # candidate and let the actual encode report the real problem.
        _encoder_cache[format_name] = candidates[0]
        return candidates[0]

    # Each encoder row reads "<flags> <name> <description>". Collect the name
    # column and compare whole names: a substring test on the raw text would
    # let "opus" match "libopus" rows or words in the descriptions. Legend
    # rows contribute harmless tokens such as "=".
    available = set()
    for row in proc.stdout.splitlines():
        columns = row.split()
        if len(columns) >= 2:
            available.add(columns[1])

    for enc in candidates:
        if enc in available:
            _encoder_cache[format_name] = enc
            logger.info("ffmpeg 编码器检测: %s → %s", format_name, enc)
            return enc

    _encoder_cache[format_name] = candidates[0]
    logger.warning("ffmpeg 编码器 %s 均未检测到,尝试 %s", candidates, candidates[0])
    return candidates[0]
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def _write_ffmpeg(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
):
    """Write MP3/OGG through ffmpeg, auto-detecting the encoder.

    The audio is first written to a temporary 16-bit WAV next to ``path``,
    which ffmpeg then transcodes; the temporary file is always removed.
    """
    scratch_wav = "".join((path, ".tmp.wav.", uuid.uuid4().hex[:12]))
    try:
        _write_sf(scratch_wav, data, sample_rate, FORMATS["wav"], "PCM_16", None)

        encoder = _detect_encoder(cfg.name)
        cmd = [
            "ffmpeg", "-y", "-i", scratch_wav, "-codec:a", encoder,
            *(["-b:a", bitrate] if bitrate else []),
            # The opus encoder is flagged experimental in some Windows builds.
            *(["-strict", "-2"] if encoder == "opus" else []),
            "-f", cfg.name, path,
        ]

        _run_ffmpeg(cmd)
    finally:
        if os.path.exists(scratch_wav):
            os.unlink(scratch_wav)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _run_ffmpeg(cmd: list, timeout: int = 300):
    """Run an ffmpeg command and turn failures into codec errors.

    Args:
        cmd: Full argument list, starting with the ffmpeg executable.
        timeout: Maximum run time in seconds.

    Raises:
        FfmpegNotFound: If ffmpeg is not available.
        EncodeError: If ffmpeg times out or exits with a non-zero code.
    """
    _check_ffmpeg()

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError:
        raise FfmpegNotFound()
    except subprocess.TimeoutExpired:
        raise EncodeError(
            f"ffmpeg 超时(>{timeout}s)",
            hint="可尝试降低采样率或比特率",
        )

    if result.returncode == 0:
        return

    # Keep only the last (up to) three stderr lines for the message.
    tail = result.stderr.strip().splitlines()[-3:]
    raise EncodeError(
        f"ffmpeg 编码失败 (code={result.returncode}): {'; '.join(tail)}",
        hint="检查 ffmpeg 版本和编码器支持",
    )
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _check_ffmpeg():
    """Verify that ffmpeg can be launched.

    Raises:
        FfmpegNotFound: If the executable is missing or ``ffmpeg -version``
            does not finish within 10 seconds.
    """
    probe_cmd = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe_cmd, capture_output=True, timeout=10)
    except (FileNotFoundError, subprocess.TimeoutExpired) as exc:
        raise FfmpegNotFound() from exc
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def _check_disk_space(path: str, estimated_bytes: int):
    """Check free disk space before writing, with a 10% safety margin.

    Args:
        path: Output file path; the check runs on its parent directory.
        estimated_bytes: Estimated size of the output file.

    Raises:
        DiskSpaceError: If the free space is below the estimate plus 10%.
    """
    try:
        target_dir = os.path.dirname(os.path.abspath(path)) or "."
        free_bytes = shutil.disk_usage(target_dir).free
    except OSError:
        # The check is best effort: skip silently when it cannot be done.
        return

    required = int(estimated_bytes * 1.1)
    if free_bytes >= required:
        return

    raise DiskSpaceError(
        f"磁盘空间不足: 需要约 {required / 1024**3:.2f} GB,"
        f"可用 {free_bytes / 1024**3:.2f} GB",
    )
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def _atomic_write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
):
    """Write atomically: encode to a temporary file, then ``os.replace`` it.

    The temporary file sits next to ``path`` with a random suffix. If
    encoding or the rename fails, the temporary file is removed and the
    exception is re-raised.
    """
    staging_path = "".join((path, ".tmp.", uuid.uuid4().hex[:12]))
    try:
        _do_write(staging_path, data, sample_rate, cfg, subtype, bitrate, compression)
        os.replace(staging_path, path)
    except Exception:
        leftover = os.path.exists(staging_path)
        if leftover:
            os.unlink(staging_path)
        raise
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def get_supported_formats() -> list:
    """Return one dict per registered format (for the API /models endpoint).

    Every entry has ``format``, ``description``, ``extension`` and
    ``default_subtype``; subtype, bitrate and compression details are added
    only for formats that have them.
    """

    def describe(cfg):
        info = {
            "format": cfg.name,
            "description": cfg.description,
            "extension": cfg.ext,
            "default_subtype": cfg.default_subtype,
        }
        if cfg.subtypes:
            info["subtypes"] = list(cfg.subtypes)
        if cfg.supports_bitrate:
            info["bitrate_range"] = list(cfg.bitrate_range)
            info["default_bitrate"] = cfg.default_bitrate
        if cfg.supports_compression:
            info["compression_range"] = list(cfg.compression_range)
            info["default_compression"] = cfg.default_compression
        return info

    return [describe(cfg) for cfg in FORMATS.values()]
|
zipenhancer/denoise.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
import tempfile
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import librosa
|
|
7
|
+
import numpy as np
|
|
8
|
+
import soundfile as sf
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from zipenhancer.standalone import ZipEnhancerStandalone
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# 模型列表
|
|
16
|
+
MODEL_ZIPENHANCER = "iic/speech_zipenhancer_ans_multiloss_16k_base"
|
|
17
|
+
MODEL_FRCRN = "iic/speech_frcrn_ans_cirm_16k"
|
|
18
|
+
MODEL_MOSSFORMER2 = "iic/speech_mossformer2_ans_48k"
|
|
19
|
+
|
|
20
|
+
# 模型固定使用 16kHz 处理
|
|
21
|
+
PROCESS_SR = 16000
|
|
22
|
+
|
|
23
|
+
# 模型缓存
|
|
24
|
+
_model_pipeline = None
|
|
25
|
+
_current_model_name: Optional[str] = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_model(model_name: str):
    """Load a model and return it without caching.

    ``MODEL_ZIPENHANCER`` is served by the bundled ``ZipEnhancerStandalone``;
    any other name is loaded as a modelscope noise-suppression pipeline, and
    modelscope is imported only on that path.
    """
    if model_name == MODEL_ZIPENHANCER:
        logger.info(f"加载剥离后的 ZipEnhancer: {model_name}")
        started_at = time.time()
        loaded = ZipEnhancerStandalone(model_name)
    else:
        from modelscope.pipelines import pipeline
        from modelscope.utils.constant import Tasks

        logger.info(f"加载模型: {model_name}")
        started_at = time.time()
        loaded = pipeline(
            Tasks.acoustic_noise_suppression,
            model=model_name,
            disable_update=True,
            disable_log=True,
        )

    logger.info(f"模型加载完成,耗时: {time.time() - started_at:.1f}s")
    return loaded
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def ensure_model(model_name: str):
    """Make sure the named model is the one held in the module-level cache.

    A model is loaded only when ``model_name`` differs from the cached name.

    Raises:
        RuntimeError: If no model is loaded after the check.
    """
    global _model_pipeline, _current_model_name

    already_loaded = model_name == _current_model_name
    if not already_loaded:
        logger.info(f"切换模型: {_current_model_name} → {model_name}")
        _model_pipeline = load_model(model_name)
        # Record the name only after loading succeeded.
        _current_model_name = model_name

    if _model_pipeline is None:
        raise RuntimeError("模型未加载")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def normalize_audio(data: np.ndarray, target_db: float = -3.0) -> np.ndarray:
    """Peak-normalise audio so its largest absolute sample sits at ``target_db``.

    Args:
        data: Audio samples.
        target_db: Target peak level in dB relative to full scale (1.0).

    Returns:
        np.ndarray: The rescaled samples. Empty input and (near-)silent input
        with a peak of 1e-10 or less are returned unchanged. If the rescaled
        peak exceeds 0.99 it is pulled back to 0.95.
    """
    # np.max raises ValueError on an empty array; nothing to scale anyway.
    if np.size(data) == 0:
        return data

    peak = np.max(np.abs(data))
    if peak > 1e-10:
        target_peak = 10 ** (target_db / 20)
        data = data * (target_peak / peak)
        # Measure the peak once and reuse it for both the test and the scale.
        scaled_peak = np.max(np.abs(data))
        if scaled_peak > 0.99:
            data = data * 0.95 / scaled_peak
    return data
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def denoise(
    audio: np.ndarray,
    sample_rate: int,
    model: str = MODEL_ZIPENHANCER,
    normalize: bool = True,
    target_sr: int = 0,
    strength: float = 1.0,
) -> Tuple[np.ndarray, float, float]:
    """Denoise an audio signal with the selected model.

    Args:
        audio (np.ndarray): Input audio. For arrays with more than one
            dimension the first axis is read as the channel axis.
        sample_rate (int): Sample rate of ``audio``.
        model (str, optional): Model name. Defaults to MODEL_ZIPENHANCER.
        normalize (bool, optional): Whether to normalise the output level.
            Defaults to True.
        target_sr (int, optional): Output sample rate; values <= 0 keep
            ``sample_rate``. Defaults to 0.
        strength (float, optional): Denoising strength, clamped to 0.0~1.0.
            Defaults to 1.0.

    Returns:
        Tuple[np.ndarray, float, float]: (denoised audio, model processing
        time in seconds, duration in seconds of the audio fed to the model).
    """
    ensure_model(model)

    # Apply the strength only when the loaded model exposes that attribute.
    if hasattr(_model_pipeline, 'strength'):
        _model_pipeline.strength = max(0.0, min(1.0, strength))

    multi_dim = audio.ndim > 1
    orig_channels = audio.shape[0] if multi_dim else 1
    output_sr = target_sr if target_sr > 0 else sample_rate

    # The model works on mono: mix multi-channel input down.
    if not multi_dim:
        mono = audio
    elif orig_channels > 1:
        mono = librosa.to_mono(audio)
    else:
        mono = audio.flatten()

    # Resample to the fixed model rate.
    if sample_rate != PROCESS_SR:
        mono = librosa.resample(mono, orig_sr=sample_rate, target_sr=PROCESS_SR)

    duration = len(mono) / PROCESS_SR

    # The model is called with file paths, so round-trip through a temp WAV:
    # write it, let the model overwrite it in place, then read it back.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        scratch_path = handle.name

    try:
        sf.write(scratch_path, mono, PROCESS_SR)

        t0 = time.time()
        _model_pipeline(scratch_path, output_path=scratch_path)
        proc_time = time.time() - t0

        denoised, _ = sf.read(scratch_path)
    finally:
        if os.path.exists(scratch_path):
            os.unlink(scratch_path)

    # Resample to the requested output rate.
    if PROCESS_SR != output_sr:
        denoised = librosa.resample(denoised, orig_sr=PROCESS_SR, target_sr=output_sr)

    if normalize:
        denoised = normalize_audio(denoised)

    # Restore the original channel count by repeating the mono result as
    # columns, giving a [samples, channels] array.
    if orig_channels > 1:
        denoised = np.column_stack([denoised for _ in range(orig_channels)])

    return denoised, proc_time, duration
|