sonusai 0.18.7__py3-none-any.whl → 0.18.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/mixture/feature.py +11 -9
- sonusai/mixture/helpers.py +4 -2
- sonusai/mixture/mixdb.py +77 -19
- {sonusai-0.18.7.dist-info → sonusai-0.18.9.dist-info}/METADATA +1 -1
- {sonusai-0.18.7.dist-info → sonusai-0.18.9.dist-info}/RECORD +8 -9
- sonusai/post_spenh_targetf.py +0 -160
- {sonusai-0.18.7.dist-info → sonusai-0.18.9.dist-info}/WHEEL +0 -0
- {sonusai-0.18.7.dist-info → sonusai-0.18.9.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py
CHANGED
@@ -24,7 +24,6 @@ commands_doc = """
|
|
24
24
|
mkwav Make WAV files from a mixture database
|
25
25
|
onnx_predict Run ONNX predict on a trained model
|
26
26
|
plot Plot mixture data
|
27
|
-
post_spenh_targetf Run post-processing for speech enhancement targetf data
|
28
27
|
summarize_metric_spenh Summarize speech enhancement and analysis results
|
29
28
|
tplot Plot truth data
|
30
29
|
vars List custom SonusAI variables
|
sonusai/mixture/feature.py
CHANGED
@@ -27,9 +27,6 @@ def get_feature_from_audio(audio: AudioT,
|
|
27
27
|
num_classes=num_classes,
|
28
28
|
truth_mutex=truth_mutex)
|
29
29
|
|
30
|
-
feature_step_samples = fg.ftransform_R * fg.decimation * fg.step
|
31
|
-
audio = pad_audio_to_frame(audio, feature_step_samples)
|
32
|
-
|
33
30
|
audio_f = forward_transform(audio=audio,
|
34
31
|
config=TransformConfig(N=fg.ftransform_N,
|
35
32
|
R=fg.ftransform_R,
|
@@ -37,10 +34,8 @@ def get_feature_from_audio(audio: AudioT,
|
|
37
34
|
bin_end=fg.bin_end,
|
38
35
|
ttype=fg.ftransform_ttype))
|
39
36
|
|
40
|
-
|
41
|
-
|
42
|
-
feature_frames = samples // feature_step_samples
|
43
|
-
|
37
|
+
transform_frames = audio_f.shape[0]
|
38
|
+
feature_frames = transform_frames // (fg.decimation * fg.step)
|
44
39
|
feature = np.empty((feature_frames, fg.stride, fg.feature_parameters), dtype=np.float32)
|
45
40
|
|
46
41
|
feature_frame = 0
|
@@ -60,7 +55,7 @@ def get_audio_from_feature(feature: Feature,
|
|
60
55
|
truth_mutex: Optional[bool] = False) -> AudioT:
|
61
56
|
"""Apply inverse transform to feature data to generate audio data
|
62
57
|
|
63
|
-
:param feature: Feature data [frames,
|
58
|
+
:param feature: Feature data [frames, stride=1, feature_parameters]
|
64
59
|
:param feature_mode: Feature mode
|
65
60
|
:param num_classes: Number of classes
|
66
61
|
:param truth_mutex: Whether to calculate 'other' label
|
@@ -70,16 +65,23 @@ def get_audio_from_feature(feature: Feature,
|
|
70
65
|
|
71
66
|
from pyaaware import FeatureGenerator
|
72
67
|
|
68
|
+
from sonusai import SonusAIError
|
73
69
|
from .datatypes import TransformConfig
|
74
70
|
from .helpers import inverse_transform
|
75
71
|
from sonusai.utils.stacked_complex import unstack_complex
|
76
72
|
from sonusai.utils.compress import power_uncompress
|
77
73
|
|
74
|
+
if feature.ndim != 3:
|
75
|
+
raise SonusAIError('feature must have 3 dimensions: [frames, stride=1, feature_parameters]')
|
76
|
+
|
77
|
+
if feature.shape[1] != 1:
|
78
|
+
raise SonusAIError('Strided feature data is not supported for audio extraction; stride must be 1.')
|
79
|
+
|
78
80
|
fg = FeatureGenerator(feature_mode=feature_mode,
|
79
81
|
num_classes=num_classes,
|
80
82
|
truth_mutex=truth_mutex)
|
81
83
|
|
82
|
-
feature_complex = unstack_complex(feature)
|
84
|
+
feature_complex = unstack_complex(feature.squeeze())
|
83
85
|
if feature_mode[0:1] == 'h':
|
84
86
|
feature_complex = power_uncompress(feature_complex)
|
85
87
|
return np.squeeze(inverse_transform(transform=feature_complex,
|
sonusai/mixture/helpers.py
CHANGED
@@ -276,7 +276,6 @@ def read_mixture_data(name: str, items: list[str] | str) -> Any:
|
|
276
276
|
:return: Data (or tuple of data)
|
277
277
|
"""
|
278
278
|
from os.path import exists
|
279
|
-
from typing import Any
|
280
279
|
|
281
280
|
import h5py
|
282
281
|
import numpy as np
|
@@ -287,7 +286,10 @@ def read_mixture_data(name: str, items: list[str] | str) -> Any:
|
|
287
286
|
if d_name in file:
|
288
287
|
data = np.array(file[d_name])
|
289
288
|
if data.size == 1:
|
290
|
-
|
289
|
+
item = data.item()
|
290
|
+
if isinstance(item, bytes):
|
291
|
+
return item.decode('utf-8')
|
292
|
+
return item
|
291
293
|
return data
|
292
294
|
return None
|
293
295
|
|
sonusai/mixture/mixdb.py
CHANGED
@@ -214,8 +214,14 @@ class MixtureDatabase:
|
|
214
214
|
MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
|
215
215
|
])
|
216
216
|
for name in self.asr_configs:
|
217
|
+
metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
|
218
|
+
f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
219
|
+
metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
|
220
|
+
f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
221
|
+
metrics.append(MetricDoc('Target Metrics', f'basewer.{name}',
|
222
|
+
f'Word error rate of tasr.{name} vs. speech text metadata for the target'))
|
217
223
|
metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
|
218
|
-
f'Word error rate
|
224
|
+
f'Word error rate of mxasr.{name} vs. tasr.{name}'))
|
219
225
|
|
220
226
|
return metrics
|
221
227
|
|
@@ -1185,7 +1191,7 @@ class MixtureDatabase:
|
|
1185
1191
|
|
1186
1192
|
def mixture_metrics(self, m_id: int,
|
1187
1193
|
metrics: list[str],
|
1188
|
-
force: bool = False) -> list[float | int | Segsnr]:
|
1194
|
+
force: bool = False) -> list[float | int | str | Segsnr]:
|
1189
1195
|
"""Get metrics data for the given mixture ID
|
1190
1196
|
|
1191
1197
|
:param m_id: Zero-based mixture ID
|
@@ -1328,7 +1334,56 @@ class MixtureDatabase:
|
|
1328
1334
|
|
1329
1335
|
noise_stats = create_noise_stats()
|
1330
1336
|
|
1331
|
-
def
|
1337
|
+
def create_asr_config() -> Callable[[str], dict]:
|
1338
|
+
state: dict[str, dict] = {}
|
1339
|
+
|
1340
|
+
def get(asr_name) -> dict:
|
1341
|
+
nonlocal state
|
1342
|
+
if asr_name not in state:
|
1343
|
+
state[asr_name] = self.asr_configs.get(asr_name, None)
|
1344
|
+
if state[asr_name] is None:
|
1345
|
+
raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
|
1346
|
+
return state[asr_name]
|
1347
|
+
|
1348
|
+
return get
|
1349
|
+
|
1350
|
+
asr_config = create_asr_config()
|
1351
|
+
|
1352
|
+
def create_target_asr() -> Callable[[str], str]:
|
1353
|
+
state: dict[str, str] = {}
|
1354
|
+
|
1355
|
+
def get(asr_name) -> str:
|
1356
|
+
nonlocal state
|
1357
|
+
if asr_name not in state:
|
1358
|
+
state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
|
1359
|
+
return state[asr_name]
|
1360
|
+
|
1361
|
+
return get
|
1362
|
+
|
1363
|
+
target_asr = create_target_asr()
|
1364
|
+
|
1365
|
+
def create_mixture_asr() -> Callable[[str], str]:
|
1366
|
+
state: dict[str, str] = {}
|
1367
|
+
|
1368
|
+
def get(asr_name) -> str:
|
1369
|
+
nonlocal state
|
1370
|
+
if asr_name not in state:
|
1371
|
+
state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
|
1372
|
+
return state[asr_name]
|
1373
|
+
|
1374
|
+
return get
|
1375
|
+
|
1376
|
+
mixture_asr = create_mixture_asr()
|
1377
|
+
|
1378
|
+
def get_asr_name(m: str) -> str:
|
1379
|
+
parts = m.split('.')
|
1380
|
+
if len(parts) != 2:
|
1381
|
+
raise SonusAIError(
|
1382
|
+
f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1383
|
+
asr_name = parts[1]
|
1384
|
+
return asr_name
|
1385
|
+
|
1386
|
+
def calc(m: str) -> float | int | str | Segsnr:
|
1332
1387
|
if m == 'mxsnr':
|
1333
1388
|
return self.mixture(m_id).snr
|
1334
1389
|
|
@@ -1340,31 +1395,31 @@ class MixtureDatabase:
|
|
1340
1395
|
|
1341
1396
|
# Otherwise, generate data as needed
|
1342
1397
|
if m.startswith('mxwer'):
|
1343
|
-
|
1344
|
-
if len(parts) != 2:
|
1345
|
-
raise SonusAIError(
|
1346
|
-
f"Unrecognized 'mxwer' format: '{m}'; must be of the form: 'mxwer.<name>'")
|
1347
|
-
asr_name = parts[1]
|
1348
|
-
asr_config = self.asr_configs.get(asr_name, None)
|
1349
|
-
if asr_config is None:
|
1350
|
-
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1398
|
+
asr_name = get_asr_name(m)
|
1351
1399
|
|
1352
1400
|
if self.mixture(m_id).snr < -96:
|
1353
1401
|
# noise only, ignore/reset target asr
|
1354
1402
|
return float('nan')
|
1355
1403
|
|
1356
|
-
|
1357
|
-
|
1358
|
-
if target_asr is None:
|
1359
|
-
target_asr = calc_asr(target_audio(), **asr_config).text
|
1404
|
+
if target_asr(asr_name):
|
1405
|
+
return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
|
1360
1406
|
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1407
|
+
# TODO: should this be NaN like above?
|
1408
|
+
return float(0)
|
1409
|
+
|
1410
|
+
if m.startswith('basewer'):
|
1411
|
+
asr_name = get_asr_name(m)
|
1412
|
+
|
1413
|
+
text = self.mixture_speech_metadata(m_id, 'text')[0]
|
1414
|
+
if text is not None:
|
1415
|
+
return calc_wer(target_asr(asr_name), text).wer * 100
|
1364
1416
|
|
1365
1417
|
# TODO: should this be NaN like above?
|
1366
1418
|
return float(0)
|
1367
1419
|
|
1420
|
+
if m.startswith('mxasr'):
|
1421
|
+
return mixture_asr(get_asr_name(m))
|
1422
|
+
|
1368
1423
|
if m == 'mxssnr_avg':
|
1369
1424
|
return calc_segsnr_f(segsnr_f()).avg
|
1370
1425
|
|
@@ -1454,6 +1509,9 @@ class MixtureDatabase:
|
|
1454
1509
|
if m == 'tpkc':
|
1455
1510
|
return target_stats().pkc
|
1456
1511
|
|
1512
|
+
if m.startswith('tasr'):
|
1513
|
+
return target_asr(get_asr_name(m))
|
1514
|
+
|
1457
1515
|
if m == 'ndco':
|
1458
1516
|
return noise_stats().dco
|
1459
1517
|
|
@@ -1501,7 +1559,7 @@ class MixtureDatabase:
|
|
1501
1559
|
|
1502
1560
|
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1503
1561
|
|
1504
|
-
result: list[float | int | Segsnr] = []
|
1562
|
+
result: list[float | int | str | Segsnr] = []
|
1505
1563
|
for metric in metrics:
|
1506
1564
|
result.append(calc(metric))
|
1507
1565
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sonusai/__init__.py,sha256=
|
1
|
+
sonusai/__init__.py,sha256=PakKXwYWgB0TZysZ6t9l6s33WobyS55qTR5jceMrADQ,3062
|
2
2
|
sonusai/aawscd_probwrite.py,sha256=GukR5owp_0A3DrqSl9fHWULYgclNft4D5OkHIwfxxkc,3698
|
3
3
|
sonusai/audiofe.py,sha256=3LssRiL73DH8teihD9f3nCvfZ0a65WQtXCqWGnKHuJM,11157
|
4
4
|
sonusai/calc_metric_spenh.py,sha256=ee2xrx6L1lFyWSoQSiq56He3RQ1cF7T_ak-6TjejXsc,47738
|
@@ -42,11 +42,11 @@ sonusai/mixture/constants.py,sha256=90qaRIEcmIoS3Od5h_UP0_SkkvG2aE_eYPv6WsIktC0,
|
|
42
42
|
sonusai/mixture/datatypes.py,sha256=2vegllgZcmFLq5NjqS7Lo97dOpOJOAj0Eml4ggP_tGo,10966
|
43
43
|
sonusai/mixture/db_datatypes.py,sha256=GDYbcSrlgUJsesiUUNnR4s5aBkMgviiNSQDaBcgYX7I,1428
|
44
44
|
sonusai/mixture/eq_rule_is_valid.py,sha256=MpQwRA5M76wSiQWEI1lW2cLFdPaMttBLcQp3tWD8efM,1243
|
45
|
-
sonusai/mixture/feature.py,sha256=
|
45
|
+
sonusai/mixture/feature.py,sha256=kYomwZpuvPQAZdb2MCaJBD8UD5LD2w5jTIkkRldaFlM,3839
|
46
46
|
sonusai/mixture/generation.py,sha256=W3n6ipI-dxg4Wj6YBJn8RTpFqkAyIXzxwObeFbSLq08,42801
|
47
|
-
sonusai/mixture/helpers.py,sha256=
|
47
|
+
sonusai/mixture/helpers.py,sha256=9x7gezEqPm5xKGAbwCqDMjedVEmoDWyFR_5-T_5nlno,24740
|
48
48
|
sonusai/mixture/log_duration_and_sizes.py,sha256=baTUpqyM15wA125jo9E3posmVJUe3WlpksyO6v9Jul0,1347
|
49
|
-
sonusai/mixture/mixdb.py,sha256=
|
49
|
+
sonusai/mixture/mixdb.py,sha256=EoH-kwg-zVJLAqpxbRKV7TtCxPqiBo3rIfdvCeZhEyI,64872
|
50
50
|
sonusai/mixture/soundfile_audio.py,sha256=BwO4lftNvrhoPTJERONcrpxSpM2fjO6kL_e5Ylz742A,4220
|
51
51
|
sonusai/mixture/sox_audio.py,sha256=DbHuyLtEuQYtKsIRxx6g1webW_LsdgLz52P5VO37MqI,17119
|
52
52
|
sonusai/mixture/sox_augmentation.py,sha256=kBWPrsFk0EBi71nLcKt5v0GA34bY7g9D9x0cEamNWbU,4564
|
@@ -68,7 +68,6 @@ sonusai/mixture/truth_functions/target.py,sha256=XypzXVMi24Ys13TiEM9JFY_cvHK61Lo
|
|
68
68
|
sonusai/mkwav.py,sha256=zfSyIiQTIK3KV9Ij33jkLhhZIMVYqaROcRQ4S7c4sIo,5364
|
69
69
|
sonusai/onnx_predict.py,sha256=jSxhD2oFyGSTHOGCXbW4fRT-k4SqKOboK2JaDO-yWcs,8737
|
70
70
|
sonusai/plot.py,sha256=ERkmxMM3qjcCDm4LGDQY4fRAncCYAzP7uW8iZ7_brcg,17105
|
71
|
-
sonusai/post_spenh_targetf.py,sha256=MBikRQfVfSZtRz9I5R3muxUtzR83S-i5INu3fAXliT4,4959
|
72
71
|
sonusai/queries/__init__.py,sha256=oKY5JeqZ4Cz7DwCwPc1_ydB8bUs6KaMcWFp_w02TjOs,255
|
73
72
|
sonusai/queries/queries.py,sha256=oV-m9uiLZOwYTK-Wo7Gf8dpGisaoGf6uDsAJAarVqZI,7553
|
74
73
|
sonusai/speech/__init__.py,sha256=SuPcU_K9wQISsZRIzsRNLtEC6cb616l-Jlx3PU-HWMs,113
|
@@ -120,7 +119,7 @@ sonusai/utils/stratified_shuffle_split.py,sha256=rJNXvBp-GxoKzH3OpL7k0ANSu5xMP2z
|
|
120
119
|
sonusai/utils/write_audio.py,sha256=ZsPGExwM86QHLLN2LOWekK2uAqf5pV_1oRW811p0QAI,840
|
121
120
|
sonusai/utils/yes_or_no.py,sha256=eMLXBVH0cEahiXY4W2KNORmwNQ-ba10eRtldh0y4NYg,263
|
122
121
|
sonusai/vars.py,sha256=m2AefF0m5bXWGXpJj8Pi42zWL2ydeEj7bkak3GrtMyM,940
|
123
|
-
sonusai-0.18.
|
124
|
-
sonusai-0.18.
|
125
|
-
sonusai-0.18.
|
126
|
-
sonusai-0.18.
|
122
|
+
sonusai-0.18.9.dist-info/METADATA,sha256=GdYfD7ldc9oJoMQxNgpG8Vs-RFOmP597X306RuMGi_M,2591
|
123
|
+
sonusai-0.18.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
124
|
+
sonusai-0.18.9.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
125
|
+
sonusai-0.18.9.dist-info/RECORD,,
|
sonusai/post_spenh_targetf.py
DELETED
@@ -1,160 +0,0 @@
|
|
1
|
-
"""sonusai post_spenh_targetf
|
2
|
-
|
3
|
-
usage: post_spenh_targetf [-hv] (-m MODEL) (-w KMODEL) INPUT ...
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-m MODEL, --model MODEL Python model file.
|
9
|
-
-w KMODEL, --weights KMODEL Keras model weights file.
|
10
|
-
|
11
|
-
Run post-processing on speech enhancement targetf prediction data.
|
12
|
-
|
13
|
-
Inputs:
|
14
|
-
MODEL A SonusAI Python model file with build and/or hypermodel functions.
|
15
|
-
KMODEL A Keras model weights file (or model file with weights).
|
16
|
-
INPUT A single H5 file or a glob of H5 files
|
17
|
-
|
18
|
-
Outputs the following to post_spenh_targetf-<TIMESTAMP> directory:
|
19
|
-
<name>.wav
|
20
|
-
post_spenh_targetf.log
|
21
|
-
|
22
|
-
"""
|
23
|
-
import signal
|
24
|
-
from dataclasses import dataclass
|
25
|
-
|
26
|
-
|
27
|
-
def signal_handler(_sig, _frame):
|
28
|
-
import sys
|
29
|
-
|
30
|
-
from sonusai import logger
|
31
|
-
|
32
|
-
logger.info('Canceled due to keyboard interrupt')
|
33
|
-
sys.exit(1)
|
34
|
-
|
35
|
-
|
36
|
-
signal.signal(signal.SIGINT, signal_handler)
|
37
|
-
|
38
|
-
|
39
|
-
@dataclass
|
40
|
-
class MPGlobal:
|
41
|
-
N: int = None
|
42
|
-
R: int = None
|
43
|
-
bin_start: int = None
|
44
|
-
bin_end: int = None
|
45
|
-
ttype: str = None
|
46
|
-
output_dir: str = None
|
47
|
-
|
48
|
-
|
49
|
-
MP_GLOBAL = MPGlobal()
|
50
|
-
|
51
|
-
|
52
|
-
def main() -> None:
|
53
|
-
from docopt import docopt
|
54
|
-
|
55
|
-
import sonusai
|
56
|
-
from sonusai.utils import trim_docstring
|
57
|
-
|
58
|
-
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
59
|
-
|
60
|
-
verbose = args['--verbose']
|
61
|
-
model_name = args['--model']
|
62
|
-
weights_name = args['--weights']
|
63
|
-
input_name = args['INPUT']
|
64
|
-
|
65
|
-
import time
|
66
|
-
from os import makedirs
|
67
|
-
from os.path import isfile
|
68
|
-
from os.path import join
|
69
|
-
from os.path import splitext
|
70
|
-
|
71
|
-
from pyaaware import FeatureGenerator
|
72
|
-
from tqdm import tqdm
|
73
|
-
|
74
|
-
from sonusai import create_file_handler
|
75
|
-
from sonusai import initial_log_messages
|
76
|
-
from sonusai import logger
|
77
|
-
from sonusai import update_console_handler
|
78
|
-
from sonusai.utils import create_ts_name
|
79
|
-
from sonusai.utils import import_and_check_keras_model
|
80
|
-
from sonusai.utils import pp_tqdm_imap
|
81
|
-
from sonusai.utils import seconds_to_hms
|
82
|
-
|
83
|
-
start_time = time.monotonic()
|
84
|
-
|
85
|
-
output_dir = create_ts_name('post_spenh_targetf')
|
86
|
-
makedirs(output_dir, exist_ok=True)
|
87
|
-
|
88
|
-
# Setup logging file
|
89
|
-
create_file_handler(join(output_dir, 'post_spenh_targetf.log'))
|
90
|
-
update_console_handler(verbose)
|
91
|
-
initial_log_messages('post_spenh_targetf')
|
92
|
-
|
93
|
-
hypermodel = import_and_check_keras_model(model_name=model_name, weights_name=weights_name)
|
94
|
-
|
95
|
-
fg = FeatureGenerator(feature_mode=hypermodel.feature,
|
96
|
-
num_classes=hypermodel.num_classes,
|
97
|
-
truth_mutex=hypermodel.truth_mutex)
|
98
|
-
|
99
|
-
MP_GLOBAL.N = fg.itransform_N
|
100
|
-
MP_GLOBAL.R = fg.itransform_R
|
101
|
-
MP_GLOBAL.bin_start = fg.bin_start
|
102
|
-
MP_GLOBAL.bin_end = fg.bin_end
|
103
|
-
MP_GLOBAL.ttype = fg.itransform_ttype
|
104
|
-
MP_GLOBAL.output_dir = output_dir
|
105
|
-
|
106
|
-
if not all(isfile(file) and splitext(file)[1] == '.h5' for file in input_name):
|
107
|
-
logger.exception(f'Do not know how to process input from {input_name}')
|
108
|
-
raise SystemExit(1)
|
109
|
-
|
110
|
-
logger.info('')
|
111
|
-
logger.info(f'Found {len(input_name):,} files to process')
|
112
|
-
|
113
|
-
progress = tqdm(total=len(input_name))
|
114
|
-
pp_tqdm_imap(_process, input_name, progress=progress)
|
115
|
-
progress.close()
|
116
|
-
|
117
|
-
logger.info(f'Wrote {len(input_name)} mixtures to {output_dir}')
|
118
|
-
logger.info('')
|
119
|
-
|
120
|
-
end_time = time.monotonic()
|
121
|
-
logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
|
122
|
-
logger.info('')
|
123
|
-
|
124
|
-
|
125
|
-
def _process(file: str) -> None:
|
126
|
-
"""Run extraction on predict data to generate estimation audio
|
127
|
-
"""
|
128
|
-
from os.path import basename
|
129
|
-
from os.path import join
|
130
|
-
from os.path import splitext
|
131
|
-
|
132
|
-
import h5py
|
133
|
-
import numpy as np
|
134
|
-
from sonusai import InverseTransform
|
135
|
-
|
136
|
-
from sonusai import SonusAIError
|
137
|
-
from sonusai.mixture import get_audio_from_transform
|
138
|
-
from sonusai.utils import float_to_int16
|
139
|
-
from sonusai.utils import unstack_complex
|
140
|
-
from sonusai.utils import write_audio
|
141
|
-
|
142
|
-
try:
|
143
|
-
with h5py.File(file, 'r') as f:
|
144
|
-
predict = unstack_complex(np.array(f['predict']))
|
145
|
-
except Exception as e:
|
146
|
-
raise SonusAIError(f'Error reading {file}: {e}')
|
147
|
-
|
148
|
-
output_name = join(MP_GLOBAL.output_dir, splitext(basename(file))[0] + '.wav')
|
149
|
-
audio, _ = get_audio_from_transform(data=predict,
|
150
|
-
transform=InverseTransform(N=MP_GLOBAL.N,
|
151
|
-
R=MP_GLOBAL.R,
|
152
|
-
bin_start=MP_GLOBAL.bin_start,
|
153
|
-
bin_end=MP_GLOBAL.bin_end,
|
154
|
-
ttype=MP_GLOBAL.ttype,
|
155
|
-
gain=np.float32(1)))
|
156
|
-
write_audio(name=output_name, audio=float_to_int16(audio))
|
157
|
-
|
158
|
-
|
159
|
-
if __name__ == '__main__':
|
160
|
-
main()
|
File without changes
|
File without changes
|