zipenhancer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ """
2
+ zipenhancer — 语音降噪核心包
3
+
4
+ 用法:
5
+ from zipenhancer import denoise, write
6
+
7
+ audio, sr = librosa.load("noisy.wav", sr=16000)
8
+ denoised, proc_time, duration = denoise(audio, sr)
9
+ write("output.flac", denoised, sr, fmt="flac")
10
+ """
11
+
12
+ from zipenhancer.codec import write, FORMATS, get_supported_formats, WriteResult
13
+ from zipenhancer.codec import FormatConfig, CodecError, FormatNotSupported
14
+ from zipenhancer.denoise import denoise, load_model, ensure_model, normalize_audio
15
+ from zipenhancer.denoise import MODEL_ZIPENHANCER, MODEL_FRCRN, MODEL_MOSSFORMER2
16
+
17
# Public API of the package, grouped by the submodule each name comes from.
__all__ = [
    # Entry points (zipenhancer.denoise / zipenhancer.codec)
    "denoise",
    "write",
    # Model management and audio helpers (zipenhancer.denoise)
    "load_model",
    "ensure_model",
    "normalize_audio",
    # Output-format registry and result/exception types (zipenhancer.codec)
    "FORMATS",
    "get_supported_formats",
    "WriteResult",
    "FormatConfig",
    "CodecError",
    "FormatNotSupported",
    # Model identifiers (zipenhancer.denoise)
    "MODEL_ZIPENHANCER",
    "MODEL_FRCRN",
    "MODEL_MOSSFORMER2",
]
zipenhancer/codec.py ADDED
@@ -0,0 +1,524 @@
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import uuid
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@dataclass(frozen=True)
class FormatConfig:
    """Immutable capability record for a single output format.

    Instances are the values of the module-level ``FORMATS`` table and are
    consulted when validating parameters, estimating file size and choosing
    the writer backend.
    """

    # Unique identifier, e.g. "wav" / "flac" / "mp3" / "ogg".
    name: str
    # Human-readable description exposed through the API.
    description: str
    # File extension, including the leading dot.
    ext: str
    # Valid subtypes; an empty tuple means the codec name decides instead.
    subtypes: tuple
    # Subtype used when the caller does not specify one.
    default_subtype: str
    # True when encoding requires the external ffmpeg executable.
    needs_ffmpeg: bool
    # Whether a compression level may be supplied.
    supports_compression: bool = False
    # Whether a bitrate may be supplied.
    supports_bitrate: bool = False
    # Inclusive (min, max) compression level; empty when unsupported.
    compression_range: tuple = ()
    # Inclusive (min, max) bitrate in bits per second; empty when unsupported.
    bitrate_range: tuple = ()
    # Bitrate string (e.g. "192k") used when none is given.
    default_bitrate: Optional[str] = None
    # Compression level used when none is given.
    default_compression: Optional[int] = None
31
+
32
+
33
# Registry of supported output formats, keyed by FormatConfig.name.
# Insertion order (wav, flac, mp3, ogg) is the order reported to callers.
FORMATS = {
    cfg.name: cfg
    for cfg in (
        # Uncompressed PCM written through soundfile.
        FormatConfig(
            name="wav",
            description="WAV（无损，兼容性最佳）",
            ext=".wav",
            subtypes=("PCM_16", "PCM_24", "PCM_32", "FLOAT"),
            default_subtype="PCM_16",
            needs_ffmpeg=False,
        ),
        # Lossless compression written through soundfile; levels 0-8.
        FormatConfig(
            name="flac",
            description="FLAC（无损压缩，文件较小）",
            ext=".flac",
            subtypes=("PCM_16", "PCM_24"),
            default_subtype="PCM_16",
            needs_ffmpeg=False,
            supports_compression=True,
            compression_range=(0, 8),
            default_compression=5,
        ),
        # Lossy, transcoded by ffmpeg; bitrate limits are in bits per second.
        FormatConfig(
            name="mp3",
            description="MP3（有损压缩，广泛兼容）",
            ext=".mp3",
            subtypes=(),
            default_subtype="libmp3lame",
            needs_ffmpeg=True,
            supports_bitrate=True,
            bitrate_range=(32000, 320000),
            default_bitrate="192k",
        ),
        # Lossy, transcoded by ffmpeg; bitrate limits are in bits per second.
        FormatConfig(
            name="ogg",
            description="OGG Opus/Vorbis（有损压缩，开源优选）",
            ext=".ogg",
            subtypes=(),
            default_subtype="opus",
            needs_ffmpeg=True,
            supports_bitrate=True,
            bitrate_range=(6000, 510000),
            default_bitrate="192k",
        ),
    )
}
76
+
77
+
78
class CodecError(Exception):
    """Base class for every encode/write error raised by this module."""

    def __init__(self, message: str, hint: str = ""):
        """Build the error text from a description and an optional hint.

        Args:
            message (str): Description of what went wrong.
            hint (str, optional): Suggestion for fixing the problem. When
                non-empty it is appended to the message inside full-width
                parentheses. Defaults to "".
        """
        self.hint = hint
        text = f"{message}（{hint}）" if hint else message
        super().__init__(text)
88
+
89
+
90
class FormatNotSupported(CodecError):
    """Raised when the requested output format is not a key of ``FORMATS``."""

    def __init__(self, fmt: str):
        # The hint lists every registered format name so the caller can pick one.
        choices = ", ".join(FORMATS)
        super().__init__(f"不支持的输出格式: {fmt}", f"可选: {choices}")
97
+
98
+
99
class SubtypeNotSupported(CodecError):
    """Raised when an encoding subtype is not valid for the chosen format."""

    def __init__(self, fmt: str, subtype: str, valid: tuple):
        # The hint enumerates the subtypes that the format does accept.
        accepted = ", ".join(valid)
        super().__init__(
            f"格式 {fmt} 不支持编码 {subtype}",
            f"{fmt} 支持的编码: {accepted}",
        )
106
+
107
+
108
class FfmpegNotFound(CodecError):
    """Raised when the ffmpeg executable cannot be run on this system."""

    def __init__(self):
        message = "系统未找到 ffmpeg"
        hint = "请安装 ffmpeg 并将其加入 PATH（https://ffmpeg.org/download.html）"
        super().__init__(message, hint)
115
+
116
+
117
class DiskSpaceError(CodecError):
    """Raised when the target volume has less free space than the estimate."""
120
+
121
+
122
class EncodeError(CodecError):
    """Raised when soundfile or ffmpeg fails to encode the audio."""
125
+
126
+
127
@dataclass
class WriteResult:
    """Outcome of a ``write`` call: the path plus the effective parameters."""

    # Destination path that was passed to write().
    path: str
    # Format name that was requested (a key of FORMATS).
    format: str
    # Subtype after validation / default filling.
    subtype: str
    # Effective bitrate string, or None when the format has no bitrate.
    bitrate: Optional[str] = None
    # Effective compression level, or None when the format has none.
    compression: Optional[int] = None
135
+
136
+
137
def write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    fmt: str = "wav",
    subtype: Optional[str] = None,
    bitrate: Optional[str] = None,
    compression: Optional[int] = None,
    atomic: bool = True,
) -> WriteResult:
    """Encode audio samples to a file in one of the registered formats.

    Args:
        path (str): Output file path. Missing parent directories are created.
        data (np.ndarray): Audio samples.
        sample_rate (int): Sample rate of ``data``.
        fmt (str, optional): Output format, a key of ``FORMATS``.
            Defaults to "wav".
        subtype (Optional[str], optional): Encoding subtype. Defaults to None,
            meaning the format's default.
        bitrate (Optional[str], optional): Bitrate such as "192k".
            Defaults to None, meaning the format's default.
        compression (Optional[int], optional): Compression level.
            Defaults to None, meaning the format's default.
        atomic (bool, optional): Write to a temporary file first and then
            replace the target. Defaults to True.

    Raises:
        FormatNotSupported: ``fmt`` is not a key of ``FORMATS``.
        DiskSpaceError: The free-space check found too little room.
        CodecError: Propagated from the underlying writers.

    Returns:
        WriteResult: The path together with the parameters actually used.
    """
    if fmt not in FORMATS:
        raise FormatNotSupported(fmt)
    cfg = FORMATS[fmt]

    # Make sure the destination directory exists before anything is written.
    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)

    # Step 1: validate the encoding parameters and fill in defaults.
    subtype, bitrate, compression = _resolve_params(cfg, subtype, bitrate, compression)

    # Step 2: estimate the output size and verify there is room for it.
    _check_disk_space(
        path,
        _estimate_size(data, sample_rate, cfg, subtype, bitrate, compression),
    )

    # Step 3: write, either via a temporary file or directly to the target.
    writer = _atomic_write if atomic else _do_write
    writer(path, data, sample_rate, cfg, subtype, bitrate, compression)

    return WriteResult(
        path=path,
        format=fmt,
        subtype=subtype,
        bitrate=bitrate,
        compression=compression,
    )
192
+
193
+
194
def _resolve_params(
    cfg: FormatConfig,
    subtype: Optional[str],
    bitrate: Optional[str],
    compression: Optional[int],
) -> tuple:
    """Validate the encoding parameters and fill in format defaults.

    Invalid values are never fatal: each one is replaced by the format's
    default and a warning is logged.

    Args:
        cfg: Capability record of the target format.
        subtype: Requested subtype, or None for the default.
        bitrate: Requested bitrate (e.g. "192k"), or None for the default.
        compression: Requested compression level, or None for the default.

    Returns:
        tuple: ``(subtype, bitrate, compression)``. ``bitrate`` is a stripped
        string or None; ``bitrate`` / ``compression`` are None when the format
        does not support them.
    """
    # --- subtype ---------------------------------------------------------
    if subtype is None:
        subtype = cfg.default_subtype
    elif cfg.subtypes and subtype not in cfg.subtypes:
        # An empty cfg.subtypes means "no fixed list", so anything passes.
        logger.warning(
            "格式 %s 不支持编码 %s，降级为 %s",
            cfg.name, subtype, cfg.default_subtype,
        )
        subtype = cfg.default_subtype

    # --- bitrate ---------------------------------------------------------
    if not cfg.supports_bitrate:
        bitrate = None
    elif bitrate is None:
        bitrate = cfg.default_bitrate
    else:
        try:
            bps = _parse_bitrate(bitrate)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers inputs such as "infk".
            logger.warning("比特率格式无效 %s，使用默认值 %s", bitrate, cfg.default_bitrate)
            bitrate = cfg.default_bitrate
        else:
            low, high = cfg.bitrate_range[0], cfg.bitrate_range[1]
            if bps < low or bps > high:
                logger.warning(
                    "比特率 %s 超出 %s 范围 (%d-%d bps)，使用默认值 %s",
                    bitrate, cfg.name, low, high,
                    cfg.default_bitrate,
                )
                bitrate = cfg.default_bitrate
            else:
                # The value is later placed in the ffmpeg argument list, which
                # only accepts strings; normalise ints and stray whitespace.
                bitrate = str(bitrate).strip()

    # --- compression -----------------------------------------------------
    if not cfg.supports_compression:
        compression = None
    elif compression is None:
        compression = cfg.default_compression
    elif compression < cfg.compression_range[0] or compression > cfg.compression_range[1]:
        logger.warning(
            "压缩级别 %d 超出 %s 范围 (%d-%d)，使用默认值 %d",
            compression, cfg.name, cfg.compression_range[0], cfg.compression_range[1],
            cfg.default_compression,
        )
        compression = cfg.default_compression

    return subtype, bitrate, compression
248
+
249
+
250
+ def _parse_bitrate(bitrate: str) -> int:
251
+ """将比特率字符串转为 bps,如 '192k' → 192000"""
252
+ s = str(bitrate).strip().lower()
253
+ if s.endswith("k"):
254
+ return int(float(s[:-1]) * 1000)
255
+ elif s.endswith("m"):
256
+ return int(float(s[:-1]) * 1_000_000)
257
+ else:
258
+ return int(s)
259
+
260
+
261
def _estimate_size(
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
) -> int:
    """Estimate the size in bytes of the encoded output file.

    Args:
        data: Samples; axis 0 is taken as the sample count and, for
            multi-dimensional input, axis 1 as the channel count.
        sample_rate: Sample rate used to derive the duration.
        cfg: Capability record of the target format.
        subtype: Effective subtype (determines bytes per sample).
        bitrate: Effective bitrate string for lossy formats, or None.
        compression: Unused by the estimate.

    Returns:
        int: Estimated number of bytes.
    """
    n_samples = data.shape[0]
    n_channels = 1
    if data.ndim != 1 and data.shape[1] > 0:
        n_channels = data.shape[1]

    duration = n_samples / sample_rate
    fmt = cfg.name

    if fmt in ("wav", "flac"):
        pcm_bytes = n_samples * n_channels * _subtype_bytes(subtype)
        # WAV: raw PCM plus header overhead.
        # FLAC: typically 50-70% of the raw PCM size.
        ratio = 1.02 if fmt == "wav" else 0.65
        return int(pcm_bytes * ratio)

    if fmt in ("mp3", "ogg") and bitrate:
        # Constant-bitrate approximation with 5% container overhead.
        return int(_parse_bitrate(bitrate) / 8 * duration * 1.05)

    # Unknown format or missing bitrate: assume 4 bytes per sample plus 10%.
    return int(n_samples * n_channels * 4 * 1.1)
294
+
295
+
296
+ def _subtype_bytes(subtype: str) -> int:
297
+ """subtype → 每样本字节数"""
298
+ return {
299
+ "PCM_S8": 1, "PCM_U8": 1,
300
+ "PCM_16": 2,
301
+ "PCM_24": 3,
302
+ "PCM_32": 4, "FLOAT": 4,
303
+ "DOUBLE": 8,
304
+ }.get(subtype, 4)
305
+
306
+
307
def _do_write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
):
    """Dispatch the write to the backend that handles ``cfg``'s format.

    WAV and FLAC go through soundfile; formats flagged ``needs_ffmpeg`` go
    through ffmpeg.

    Raises:
        EncodeError: No backend handles the format.
    """
    handled_by_soundfile = cfg.name in ("wav", "flac")
    if handled_by_soundfile:
        _write_sf(path, data, sample_rate, cfg, subtype, compression)
        return

    if not cfg.needs_ffmpeg:
        raise EncodeError(f"未知格式: {cfg.name}")

    _write_ffmpeg(path, data, sample_rate, cfg, subtype, bitrate)
323
+
324
+
325
def _write_sf(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    compression: Optional[int],
):
    """Write WAV or FLAC through soundfile.

    Args:
        path: Destination file path.
        data: Samples shaped ``[samples]`` or ``[samples, channels]``, which
            is the layout soundfile expects.
        sample_rate: Sample rate written to the file.
        cfg: Capability record of the target format ("wav" or "flac").
        subtype: soundfile subtype, e.g. "PCM_16".
        compression: FLAC compression level on ``cfg.compression_range``'s
            scale, or None to use the library default.

    Raises:
        EncodeError: soundfile failed for any reason.
    """
    sf_format = cfg.name.upper()

    # Plain WAV cannot address more than 4 GB, so large outputs switch to
    # RF64. The decision is based on the *encoded* size (sample count times
    # the subtype's byte width), not on the in-memory dtype of ``data``.
    if cfg.name == "wav":
        encoded_bytes = data.size * _subtype_bytes(subtype)
        if encoded_bytes > 3.5 * 1024 ** 3:
            sf_format = "RF64"
            logger.info("文件超过 4GB，自动使用 RF64 格式")

    sf_kwargs = {"format": sf_format, "subtype": subtype}
    if compression is not None and cfg.name == "flac":
        # soundfile 0.13.x takes compression_level in 0.0-1.0, so scale the
        # level by the upper bound of the configured range.
        max_level = cfg.compression_range[1] if cfg.compression_range else 8
        sf_kwargs["compression_level"] = compression / float(max_level)

    try:
        sf.write(path, data, sample_rate, **sf_kwargs)
    except Exception as e:
        raise EncodeError(
            f"soundfile 编码失败: {e}",
            hint="检查磁盘空间和文件权限",
        ) from e
358
+
359
+
360
+ # 每个格式的首选编码器列表(按优先级)
361
+ _FFMPEG_ENCODERS = {
362
+ "mp3": ["libmp3lame", "mp3_mf"],
363
+ "ogg": ["libvorbis", "opus", "vorbis"],
364
+ }
365
+ _encoder_cache = {}
366
+
367
+
368
def _detect_encoder(format_name: str) -> str:
    """Pick the first candidate encoder that the local ffmpeg provides.

    The result is cached per format. When ffmpeg cannot be queried, the first
    candidate is returned without caching, so a later call can probe again.

    Args:
        format_name: Key of ``_FFMPEG_ENCODERS`` ("mp3" or "ogg").

    Returns:
        str: Encoder name for ``-codec:a``, or "" when the format has no
        candidates.
    """
    if format_name in _encoder_cache:
        return _encoder_cache[format_name]

    candidates = _FFMPEG_ENCODERS.get(format_name, [])
    if not candidates:
        _encoder_cache[format_name] = ""
        return ""

    try:
        proc = subprocess.run(
            ["ffmpeg", "-encoders"],
            capture_output=True, text=True, errors="replace", timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        # ffmpeg missing or unresponsive: fall back, but do not cache so the
        # probe is retried once ffmpeg becomes available.
        return candidates[0]

    # Each encoder line reads "<flags> <name> <description>". Compare against
    # the exact name column; a substring test over the whole output would
    # report "opus" as present whenever only "libopus" is listed.
    available = set()
    for line in (proc.stdout or "").splitlines():
        fields = line.split()
        if len(fields) >= 2:
            available.add(fields[1])

    for enc in candidates:
        if enc in available:
            _encoder_cache[format_name] = enc
            logger.info("ffmpeg 编码器检测: %s → %s", format_name, enc)
            return enc

    _encoder_cache[format_name] = candidates[0]
    logger.warning("ffmpeg 编码器 %s 均未检测到，尝试 %s", candidates, candidates[0])
    return candidates[0]
397
+
398
+
399
def _write_ffmpeg(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
):
    """Encode MP3/OGG by transcoding a temporary WAV with ffmpeg.

    The samples are first written as a 16-bit PCM WAV next to ``path``; ffmpeg
    then converts that file using the encoder chosen by ``_detect_encoder``.
    The temporary WAV is removed whether or not encoding succeeds.
    ``subtype`` is accepted for signature parity but is not used here.
    """
    scratch_wav = f"{path}.tmp.wav.{uuid.uuid4().hex[:12]}"
    try:
        _write_sf(scratch_wav, data, sample_rate, FORMATS["wav"], "PCM_16", None)

        encoder = _detect_encoder(cfg.name)

        cmd = ["ffmpeg", "-y", "-i", scratch_wav, "-codec:a", encoder]
        if bitrate:
            cmd += ["-b:a", bitrate]
        if encoder == "opus":
            # The native opus encoder is flagged experimental in some Windows
            # builds and needs this switch to be usable.
            cmd += ["-strict", "-2"]
        cmd += ["-f", cfg.name, path]

        _run_ffmpeg(cmd)
    finally:
        if os.path.exists(scratch_wav):
            os.unlink(scratch_wav)
428
+
429
+
430
def _run_ffmpeg(cmd: list, timeout: int = 300):
    """Run an ffmpeg command line and raise on any failure.

    Args:
        cmd: Full argument list, starting with "ffmpeg".
        timeout: Seconds to wait before giving up. Defaults to 300.

    Raises:
        FfmpegNotFound: ffmpeg is not available.
        EncodeError: ffmpeg timed out or exited with a non-zero status.
    """
    _check_ffmpeg()

    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # ffmpeg output is not guaranteed to be valid in the locale
            # encoding; never let a decode error hide the real result.
            errors="replace",
            timeout=timeout,
        )
    except FileNotFoundError as e:
        raise FfmpegNotFound() from e
    except subprocess.TimeoutExpired as e:
        raise EncodeError(
            f"ffmpeg 超时（>{timeout}s）",
            hint="可尝试降低采样率或比特率",
        ) from e

    if proc.returncode != 0:
        # Only the tail of stderr is useful; the head is banner/config noise.
        brief = (proc.stderr or "").strip().splitlines()[-3:]
        raise EncodeError(
            f"ffmpeg 编码失败 (code={proc.returncode}): {'; '.join(brief)}",
            hint="检查 ffmpeg 版本和编码器支持",
        )
456
+
457
+
458
def _check_ffmpeg():
    """Verify that ffmpeg can be executed.

    Raises:
        FfmpegNotFound: ``ffmpeg -version`` could not be started (missing,
            not executable, ...), timed out, or exited with a non-zero status.
    """
    try:
        proc = subprocess.run(
            ["ffmpeg", "-version"],
            capture_output=True,
            timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired) as e:
        # OSError covers FileNotFoundError as well as PermissionError etc.
        raise FfmpegNotFound() from e

    if proc.returncode != 0:
        # The binary exists but cannot even print its version: unusable.
        raise FfmpegNotFound()
468
+
469
+
470
def _check_disk_space(path: str, estimated_bytes: int):
    """Check free disk space before writing, with a 10% safety margin.

    The check is best-effort: when the free space cannot be determined the
    function logs the reason and returns without raising.

    Args:
        path: Destination file path; its parent directory is inspected.
        estimated_bytes: Expected size of the output file.

    Raises:
        DiskSpaceError: Free space is below ``estimated_bytes * 1.1``.
    """
    try:
        parent = os.path.dirname(os.path.abspath(path)) or "."
        usage = shutil.disk_usage(parent)
    except OSError as e:
        # Not being able to check must not block the write itself.
        logger.debug("无法检查磁盘空间，已跳过: %s", e)
        return

    needed = int(estimated_bytes * 1.1)
    if usage.free < needed:
        raise DiskSpaceError(
            f"磁盘空间不足: 需要约 {needed / 1024**3:.2f} GB，"
            f"可用 {usage.free / 1024**3:.2f} GB",
        )
483
+
484
+
485
def _atomic_write(
    path: str,
    data: np.ndarray,
    sample_rate: int,
    cfg: FormatConfig,
    subtype: str,
    bitrate: Optional[str],
    compression: Optional[int],
):
    """Write to a temporary sibling file, then ``os.replace`` it onto ``path``.

    On any failure, including interruption, the temporary file is removed and
    the original exception is re-raised, so ``path`` is never left holding a
    partially written file.
    """
    tmp_path = path + ".tmp." + uuid.uuid4().hex[:12]
    try:
        _do_write(tmp_path, data, sample_rate, cfg, subtype, bitrate, compression)
        os.replace(tmp_path, path)
    except BaseException:
        # BaseException so Ctrl-C / SystemExit also clean up the temp file.
        try:
            os.unlink(tmp_path)
        except OSError:
            # Nothing to remove, or removal failed: keep the original error.
            pass
        raise
503
+
504
+
505
def get_supported_formats() -> list:
    """Return one plain dict per registered format (used by the API /models endpoint).

    Every entry has ``format``, ``description``, ``extension`` and
    ``default_subtype``. ``subtypes``, the bitrate fields and the compression
    fields are added only when the format defines or supports them.
    """

    def describe(cfg) -> dict:
        entry = dict(
            format=cfg.name,
            description=cfg.description,
            extension=cfg.ext,
            default_subtype=cfg.default_subtype,
        )
        if cfg.subtypes:
            entry["subtypes"] = list(cfg.subtypes)
        if cfg.supports_bitrate:
            entry.update(
                bitrate_range=list(cfg.bitrate_range),
                default_bitrate=cfg.default_bitrate,
            )
        if cfg.supports_compression:
            entry.update(
                compression_range=list(cfg.compression_range),
                default_compression=cfg.default_compression,
            )
        return entry

    return [describe(cfg) for cfg in FORMATS.values()]
zipenhancer/denoise.py ADDED
@@ -0,0 +1,144 @@
1
+ import os
2
+ import time
3
+ import tempfile
4
+ from typing import Optional, Tuple
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import soundfile as sf
9
+
10
+ import logging
11
+ from zipenhancer.standalone import ZipEnhancerStandalone
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
# Identifiers of the selectable models.
MODEL_ZIPENHANCER = "iic/speech_zipenhancer_ans_multiloss_16k_base"
MODEL_FRCRN = "iic/speech_frcrn_ans_cirm_16k"
MODEL_MOSSFORMER2 = "iic/speech_mossformer2_ans_48k"

# Audio is always resampled to this rate (16 kHz) before the model runs.
PROCESS_SR = 16000

# Currently loaded model and the identifier it was loaded from.
_model_pipeline = None
_current_model_name: Optional[str] = None
26
+
27
+
28
def load_model(model_name: str):
    """Load and return a model object without caching it.

    ``MODEL_ZIPENHANCER`` is served by the bundled ``ZipEnhancerStandalone``
    class. Every other name is loaded as a ModelScope noise-suppression
    pipeline; ModelScope is imported only on that path.

    Args:
        model_name: Model identifier.

    Returns:
        The loaded model object.
    """
    if model_name != MODEL_ZIPENHANCER:
        from modelscope.pipelines import pipeline
        from modelscope.utils.constant import Tasks

        logger.info(f"加载模型: {model_name}")
        started = time.time()
        model = pipeline(
            Tasks.acoustic_noise_suppression,
            model=model_name,
            disable_update=True,
            disable_log=True,
        )
    else:
        logger.info(f"加载剥离后的 ZipEnhancer: {model_name}")
        started = time.time()
        model = ZipEnhancerStandalone(model_name)

    logger.info(f"模型加载完成，耗时: {time.time() - started:.1f}s")
    return model
50
+
51
+
52
def ensure_model(model_name: str):
    """Make sure ``model_name`` is the loaded model, switching only if needed.

    Updates the module-level ``_model_pipeline`` / ``_current_model_name``.

    Raises:
        RuntimeError: No model object is available after the check.
    """
    global _model_pipeline, _current_model_name

    already_loaded = model_name == _current_model_name
    if not already_loaded:
        logger.info(f"切换模型: {_current_model_name} → {model_name}")
        _model_pipeline = load_model(model_name)
        # Record the name only after loading succeeded.
        _current_model_name = model_name

    if _model_pipeline is None:
        raise RuntimeError("模型未加载")
63
+
64
+
65
def normalize_audio(data: np.ndarray, target_db: float = -3.0) -> np.ndarray:
    """Peak-normalize audio to ``target_db`` dBFS.

    Near-silent input (peak <= 1e-10) is not scaled, to avoid amplifying
    noise. If the result still peaks above 0.99 it is rescaled to a 0.95
    peak. Empty input is returned unchanged.

    Args:
        data: Audio samples.
        target_db: Desired peak level in dB relative to full scale.
            Defaults to -3.0.

    Returns:
        np.ndarray: The scaled samples (a new array whenever scaling occurred).
    """
    # np.max raises on zero-size arrays; there is nothing to normalize anyway.
    if np.size(data) == 0:
        return data

    peak = np.max(np.abs(data))
    if peak > 1e-10:
        target_peak = 10 ** (target_db / 20)
        data = data * (target_peak / peak)

    # Safety clamp, e.g. when target_db is close to or above 0 dB.
    peak = np.max(np.abs(data))
    if peak > 0.99:
        data = data * 0.95 / peak
    return data
74
+
75
+
76
def denoise(
    audio: np.ndarray,
    sample_rate: int,
    model: str = MODEL_ZIPENHANCER,
    normalize: bool = True,
    target_sr: int = 0,
    strength: float = 1.0,
) -> Tuple[np.ndarray, float, float]:
    """Denoise audio with the selected model.

    The input is mixed down to mono, resampled to ``PROCESS_SR``, passed
    through the model via a temporary WAV file, then resampled to the output
    rate and optionally peak-normalized.

    Args:
        audio (np.ndarray): Input audio. Multi-dimensional input is read as
            ``[channels, samples]`` (axis 0 is the channel count).
        sample_rate (int): Sample rate of ``audio``.
        model (str, optional): Model identifier. Defaults to MODEL_ZIPENHANCER.
        normalize (bool, optional): Apply ``normalize_audio`` to the result.
            Defaults to True.
        target_sr (int, optional): Output sample rate; 0 or less keeps
            ``sample_rate``. Defaults to 0.
        strength (float, optional): Denoising strength, clamped to 0.0-1.0 and
            applied only if the model object has a ``strength`` attribute.
            Defaults to 1.0.

    Returns:
        Tuple[np.ndarray, float, float]: ``(denoised audio, processing time in
        seconds, duration in seconds)``. For multi-channel input the mono
        result is duplicated per channel and returned as
        ``[samples, channels]`` — note this differs from the input layout.
    """
    ensure_model(model)
    # Apply the denoising strength when the loaded model exposes one.
    if hasattr(_model_pipeline, 'strength'):
        _model_pipeline.strength = max(0.0, min(1.0, strength))

    orig_channels = audio.shape[0] if audio.ndim > 1 else 1
    output_sr = target_sr if target_sr > 0 else sample_rate

    # The model works on mono: mix down multi-channel input.
    if audio.ndim > 1 and audio.shape[0] > 1:
        audio_mono = librosa.to_mono(audio)
    else:
        audio_mono = audio.flatten() if audio.ndim > 1 else audio

    # Resample to the fixed processing rate.
    # NOTE(review): MODEL_MOSSFORMER2 is named "48k" yet is fed PROCESS_SR
    # audio like the others — confirm this is what that model expects.
    if sample_rate != PROCESS_SR:
        audio_mono = librosa.resample(audio_mono, orig_sr=sample_rate, target_sr=PROCESS_SR)

    duration = len(audio_mono) / PROCESS_SR

    # The model interface only accepts file paths: write a temporary WAV,
    # let the model overwrite it in place, then read the result back.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_raw:
        tmp_path = tmp_raw.name

    try:
        sf.write(tmp_path, audio_mono, PROCESS_SR)

        start = time.time()
        _model_pipeline(tmp_path, output_path=tmp_path)
        proc_time = time.time() - start

        # Trust the rate stored in the file the model produced rather than
        # assuming it equals PROCESS_SR.
        denoised, model_sr = sf.read(tmp_path)
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

    # sf.read returns [samples, channels] for multi-channel files; collapse to
    # mono so the resampling below runs along the time axis.
    if denoised.ndim > 1:
        denoised = denoised.mean(axis=1)

    # Resample from the model's actual output rate to the requested rate.
    if model_sr != output_sr:
        denoised = librosa.resample(denoised, orig_sr=model_sr, target_sr=output_sr)

    if normalize:
        denoised = normalize_audio(denoised)

    # Restore the original channel count by duplicating the mono signal.
    if orig_channels > 1:
        denoised = np.column_stack([denoised] * orig_channels)

    return denoised, proc_time, duration
@@ -0,0 +1,4 @@
1
+ # ------ coding : utf-8 ------
2
+ # @FileName : __init__.py.py
3
+ # @Author : lxc
4
+ # @Time : 2025/3/4 11:39
@@ -0,0 +1,4 @@
1
+ # ------ coding : utf-8 ------
2
+ # @FileName : __init__.py.py
3
+ # @Author : lxc
4
+ # @Time : 2025/3/4 11:39