sonusai 0.18.7__py3-none-any.whl → 0.18.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/mixture/helpers.py +4 -2
- sonusai/mixture/mixdb.py +65 -19
- {sonusai-0.18.7.dist-info → sonusai-0.18.8.dist-info}/METADATA +1 -1
- {sonusai-0.18.7.dist-info → sonusai-0.18.8.dist-info}/RECORD +7 -8
- sonusai/post_spenh_targetf.py +0 -160
- {sonusai-0.18.7.dist-info → sonusai-0.18.8.dist-info}/WHEEL +0 -0
- {sonusai-0.18.7.dist-info → sonusai-0.18.8.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py
CHANGED
@@ -24,7 +24,6 @@ commands_doc = """
|
|
24
24
|
mkwav Make WAV files from a mixture database
|
25
25
|
onnx_predict Run ONNX predict on a trained model
|
26
26
|
plot Plot mixture data
|
27
|
-
post_spenh_targetf Run post-processing for speech enhancement targetf data
|
28
27
|
summarize_metric_spenh Summarize speech enhancement and analysis results
|
29
28
|
tplot Plot truth data
|
30
29
|
vars List custom SonusAI variables
|
sonusai/mixture/helpers.py
CHANGED
@@ -276,7 +276,6 @@ def read_mixture_data(name: str, items: list[str] | str) -> Any:
|
|
276
276
|
:return: Data (or tuple of data)
|
277
277
|
"""
|
278
278
|
from os.path import exists
|
279
|
-
from typing import Any
|
280
279
|
|
281
280
|
import h5py
|
282
281
|
import numpy as np
|
@@ -287,7 +286,10 @@ def read_mixture_data(name: str, items: list[str] | str) -> Any:
|
|
287
286
|
if d_name in file:
|
288
287
|
data = np.array(file[d_name])
|
289
288
|
if data.size == 1:
|
290
|
-
|
289
|
+
item = data.item()
|
290
|
+
if isinstance(item, bytes):
|
291
|
+
return item.decode('utf-8')
|
292
|
+
return item
|
291
293
|
return data
|
292
294
|
return None
|
293
295
|
|
sonusai/mixture/mixdb.py
CHANGED
@@ -214,6 +214,10 @@ class MixtureDatabase:
|
|
214
214
|
MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
|
215
215
|
])
|
216
216
|
for name in self.asr_configs:
|
217
|
+
metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
|
218
|
+
f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
219
|
+
metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
|
220
|
+
f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
217
221
|
metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
|
218
222
|
f'Word error rate using {name} ASR as defined in mixdb asr_configs parameter'))
|
219
223
|
|
@@ -1185,7 +1189,7 @@ class MixtureDatabase:
|
|
1185
1189
|
|
1186
1190
|
def mixture_metrics(self, m_id: int,
|
1187
1191
|
metrics: list[str],
|
1188
|
-
force: bool = False) -> list[float | int | Segsnr]:
|
1192
|
+
force: bool = False) -> list[float | int | str | Segsnr]:
|
1189
1193
|
"""Get metrics data for the given mixture ID
|
1190
1194
|
|
1191
1195
|
:param m_id: Zero-based mixture ID
|
@@ -1328,7 +1332,56 @@ class MixtureDatabase:
|
|
1328
1332
|
|
1329
1333
|
noise_stats = create_noise_stats()
|
1330
1334
|
|
1331
|
-
def
|
1335
|
+
def create_asr_config() -> Callable[[str], dict]:
|
1336
|
+
state: dict[str, dict] = {}
|
1337
|
+
|
1338
|
+
def get(asr_name) -> dict:
|
1339
|
+
nonlocal state
|
1340
|
+
if asr_name not in state:
|
1341
|
+
state[asr_name] = self.asr_configs.get(asr_name, None)
|
1342
|
+
if state[asr_name] is None:
|
1343
|
+
raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
|
1344
|
+
return state[asr_name]
|
1345
|
+
|
1346
|
+
return get
|
1347
|
+
|
1348
|
+
asr_config = create_asr_config()
|
1349
|
+
|
1350
|
+
def create_target_asr() -> Callable[[str], str]:
|
1351
|
+
state: dict[str, str] = {}
|
1352
|
+
|
1353
|
+
def get(asr_name) -> str:
|
1354
|
+
nonlocal state
|
1355
|
+
if asr_name not in state:
|
1356
|
+
state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
|
1357
|
+
return state[asr_name]
|
1358
|
+
|
1359
|
+
return get
|
1360
|
+
|
1361
|
+
target_asr = create_target_asr()
|
1362
|
+
|
1363
|
+
def create_mixture_asr() -> Callable[[str], str]:
|
1364
|
+
state: dict[str, str] = {}
|
1365
|
+
|
1366
|
+
def get(asr_name) -> str:
|
1367
|
+
nonlocal state
|
1368
|
+
if asr_name not in state:
|
1369
|
+
state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
|
1370
|
+
return state[asr_name]
|
1371
|
+
|
1372
|
+
return get
|
1373
|
+
|
1374
|
+
mixture_asr = create_mixture_asr()
|
1375
|
+
|
1376
|
+
def get_asr_name(m: str) -> str:
|
1377
|
+
parts = m.split('.')
|
1378
|
+
if len(parts) != 2:
|
1379
|
+
raise SonusAIError(
|
1380
|
+
f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1381
|
+
asr_name = parts[1]
|
1382
|
+
return asr_name
|
1383
|
+
|
1384
|
+
def calc(m: str) -> float | int | str | Segsnr:
|
1332
1385
|
if m == 'mxsnr':
|
1333
1386
|
return self.mixture(m_id).snr
|
1334
1387
|
|
@@ -1340,31 +1393,21 @@ class MixtureDatabase:
|
|
1340
1393
|
|
1341
1394
|
# Otherwise, generate data as needed
|
1342
1395
|
if m.startswith('mxwer'):
|
1343
|
-
|
1344
|
-
if len(parts) != 2:
|
1345
|
-
raise SonusAIError(
|
1346
|
-
f"Unrecognized 'mxwer' format: '{m}'; must be of the form: 'mxwer.<name>'")
|
1347
|
-
asr_name = parts[1]
|
1348
|
-
asr_config = self.asr_configs.get(asr_name, None)
|
1349
|
-
if asr_config is None:
|
1350
|
-
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1396
|
+
asr_name = get_asr_name(m)
|
1351
1397
|
|
1352
1398
|
if self.mixture(m_id).snr < -96:
|
1353
1399
|
# noise only, ignore/reset target asr
|
1354
1400
|
return float('nan')
|
1355
1401
|
|
1356
|
-
|
1357
|
-
|
1358
|
-
if target_asr is None:
|
1359
|
-
target_asr = calc_asr(target_audio(), **asr_config).text
|
1360
|
-
|
1361
|
-
if target_asr:
|
1362
|
-
mixture_asr = calc_asr(mixture_audio(), **asr_config).text
|
1363
|
-
return calc_wer(mixture_asr, target_asr).wer * 100
|
1402
|
+
if target_asr(asr_name):
|
1403
|
+
return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
|
1364
1404
|
|
1365
1405
|
# TODO: should this be NaN like above?
|
1366
1406
|
return float(0)
|
1367
1407
|
|
1408
|
+
if m.startswith('mxasr'):
|
1409
|
+
return mixture_asr(get_asr_name(m))
|
1410
|
+
|
1368
1411
|
if m == 'mxssnr_avg':
|
1369
1412
|
return calc_segsnr_f(segsnr_f()).avg
|
1370
1413
|
|
@@ -1454,6 +1497,9 @@ class MixtureDatabase:
|
|
1454
1497
|
if m == 'tpkc':
|
1455
1498
|
return target_stats().pkc
|
1456
1499
|
|
1500
|
+
if m.startswith('tasr'):
|
1501
|
+
return target_asr(get_asr_name(m))
|
1502
|
+
|
1457
1503
|
if m == 'ndco':
|
1458
1504
|
return noise_stats().dco
|
1459
1505
|
|
@@ -1501,7 +1547,7 @@ class MixtureDatabase:
|
|
1501
1547
|
|
1502
1548
|
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1503
1549
|
|
1504
|
-
result: list[float | int | Segsnr] = []
|
1550
|
+
result: list[float | int | str | Segsnr] = []
|
1505
1551
|
for metric in metrics:
|
1506
1552
|
result.append(calc(metric))
|
1507
1553
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
sonusai/__init__.py,sha256=
|
1
|
+
sonusai/__init__.py,sha256=PakKXwYWgB0TZysZ6t9l6s33WobyS55qTR5jceMrADQ,3062
|
2
2
|
sonusai/aawscd_probwrite.py,sha256=GukR5owp_0A3DrqSl9fHWULYgclNft4D5OkHIwfxxkc,3698
|
3
3
|
sonusai/audiofe.py,sha256=3LssRiL73DH8teihD9f3nCvfZ0a65WQtXCqWGnKHuJM,11157
|
4
4
|
sonusai/calc_metric_spenh.py,sha256=ee2xrx6L1lFyWSoQSiq56He3RQ1cF7T_ak-6TjejXsc,47738
|
@@ -44,9 +44,9 @@ sonusai/mixture/db_datatypes.py,sha256=GDYbcSrlgUJsesiUUNnR4s5aBkMgviiNSQDaBcgYX
|
|
44
44
|
sonusai/mixture/eq_rule_is_valid.py,sha256=MpQwRA5M76wSiQWEI1lW2cLFdPaMttBLcQp3tWD8efM,1243
|
45
45
|
sonusai/mixture/feature.py,sha256=bHAPRaYGyS-ZTOb-RLCwDau7n1NDKsVEW30Gd9SRZYo,3676
|
46
46
|
sonusai/mixture/generation.py,sha256=W3n6ipI-dxg4Wj6YBJn8RTpFqkAyIXzxwObeFbSLq08,42801
|
47
|
-
sonusai/mixture/helpers.py,sha256=
|
47
|
+
sonusai/mixture/helpers.py,sha256=9x7gezEqPm5xKGAbwCqDMjedVEmoDWyFR_5-T_5nlno,24740
|
48
48
|
sonusai/mixture/log_duration_and_sizes.py,sha256=baTUpqyM15wA125jo9E3posmVJUe3WlpksyO6v9Jul0,1347
|
49
|
-
sonusai/mixture/mixdb.py,sha256=
|
49
|
+
sonusai/mixture/mixdb.py,sha256=mr9Ck4p_mCfvz1PXoUgjWcw9F-Rlw3uGiDizUvPqo2A,64359
|
50
50
|
sonusai/mixture/soundfile_audio.py,sha256=BwO4lftNvrhoPTJERONcrpxSpM2fjO6kL_e5Ylz742A,4220
|
51
51
|
sonusai/mixture/sox_audio.py,sha256=DbHuyLtEuQYtKsIRxx6g1webW_LsdgLz52P5VO37MqI,17119
|
52
52
|
sonusai/mixture/sox_augmentation.py,sha256=kBWPrsFk0EBi71nLcKt5v0GA34bY7g9D9x0cEamNWbU,4564
|
@@ -68,7 +68,6 @@ sonusai/mixture/truth_functions/target.py,sha256=XypzXVMi24Ys13TiEM9JFY_cvHK61Lo
|
|
68
68
|
sonusai/mkwav.py,sha256=zfSyIiQTIK3KV9Ij33jkLhhZIMVYqaROcRQ4S7c4sIo,5364
|
69
69
|
sonusai/onnx_predict.py,sha256=jSxhD2oFyGSTHOGCXbW4fRT-k4SqKOboK2JaDO-yWcs,8737
|
70
70
|
sonusai/plot.py,sha256=ERkmxMM3qjcCDm4LGDQY4fRAncCYAzP7uW8iZ7_brcg,17105
|
71
|
-
sonusai/post_spenh_targetf.py,sha256=MBikRQfVfSZtRz9I5R3muxUtzR83S-i5INu3fAXliT4,4959
|
72
71
|
sonusai/queries/__init__.py,sha256=oKY5JeqZ4Cz7DwCwPc1_ydB8bUs6KaMcWFp_w02TjOs,255
|
73
72
|
sonusai/queries/queries.py,sha256=oV-m9uiLZOwYTK-Wo7Gf8dpGisaoGf6uDsAJAarVqZI,7553
|
74
73
|
sonusai/speech/__init__.py,sha256=SuPcU_K9wQISsZRIzsRNLtEC6cb616l-Jlx3PU-HWMs,113
|
@@ -120,7 +119,7 @@ sonusai/utils/stratified_shuffle_split.py,sha256=rJNXvBp-GxoKzH3OpL7k0ANSu5xMP2z
|
|
120
119
|
sonusai/utils/write_audio.py,sha256=ZsPGExwM86QHLLN2LOWekK2uAqf5pV_1oRW811p0QAI,840
|
121
120
|
sonusai/utils/yes_or_no.py,sha256=eMLXBVH0cEahiXY4W2KNORmwNQ-ba10eRtldh0y4NYg,263
|
122
121
|
sonusai/vars.py,sha256=m2AefF0m5bXWGXpJj8Pi42zWL2ydeEj7bkak3GrtMyM,940
|
123
|
-
sonusai-0.18.
|
124
|
-
sonusai-0.18.
|
125
|
-
sonusai-0.18.
|
126
|
-
sonusai-0.18.
|
122
|
+
sonusai-0.18.8.dist-info/METADATA,sha256=KqBUJv7yMq-3lDPNfezRqi_z28ZB0w0mDEXJfBtlVmA,2591
|
123
|
+
sonusai-0.18.8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
124
|
+
sonusai-0.18.8.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
125
|
+
sonusai-0.18.8.dist-info/RECORD,,
|
sonusai/post_spenh_targetf.py
DELETED
@@ -1,160 +0,0 @@
|
|
1
|
-
"""sonusai post_spenh_targetf
|
2
|
-
|
3
|
-
usage: post_spenh_targetf [-hv] (-m MODEL) (-w KMODEL) INPUT ...
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-m MODEL, --model MODEL Python model file.
|
9
|
-
-w KMODEL, --weights KMODEL Keras model weights file.
|
10
|
-
|
11
|
-
Run post-processing on speech enhancement targetf prediction data.
|
12
|
-
|
13
|
-
Inputs:
|
14
|
-
MODEL A SonusAI Python model file with build and/or hypermodel functions.
|
15
|
-
KMODEL A Keras model weights file (or model file with weights).
|
16
|
-
INPUT A single H5 file or a glob of H5 files
|
17
|
-
|
18
|
-
Outputs the following to post_spenh_targetf-<TIMESTAMP> directory:
|
19
|
-
<name>.wav
|
20
|
-
post_spenh_targetf.log
|
21
|
-
|
22
|
-
"""
|
23
|
-
import signal
|
24
|
-
from dataclasses import dataclass
|
25
|
-
|
26
|
-
|
27
|
-
def signal_handler(_sig, _frame):
|
28
|
-
import sys
|
29
|
-
|
30
|
-
from sonusai import logger
|
31
|
-
|
32
|
-
logger.info('Canceled due to keyboard interrupt')
|
33
|
-
sys.exit(1)
|
34
|
-
|
35
|
-
|
36
|
-
signal.signal(signal.SIGINT, signal_handler)
|
37
|
-
|
38
|
-
|
39
|
-
@dataclass
|
40
|
-
class MPGlobal:
|
41
|
-
N: int = None
|
42
|
-
R: int = None
|
43
|
-
bin_start: int = None
|
44
|
-
bin_end: int = None
|
45
|
-
ttype: str = None
|
46
|
-
output_dir: str = None
|
47
|
-
|
48
|
-
|
49
|
-
MP_GLOBAL = MPGlobal()
|
50
|
-
|
51
|
-
|
52
|
-
def main() -> None:
|
53
|
-
from docopt import docopt
|
54
|
-
|
55
|
-
import sonusai
|
56
|
-
from sonusai.utils import trim_docstring
|
57
|
-
|
58
|
-
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
59
|
-
|
60
|
-
verbose = args['--verbose']
|
61
|
-
model_name = args['--model']
|
62
|
-
weights_name = args['--weights']
|
63
|
-
input_name = args['INPUT']
|
64
|
-
|
65
|
-
import time
|
66
|
-
from os import makedirs
|
67
|
-
from os.path import isfile
|
68
|
-
from os.path import join
|
69
|
-
from os.path import splitext
|
70
|
-
|
71
|
-
from pyaaware import FeatureGenerator
|
72
|
-
from tqdm import tqdm
|
73
|
-
|
74
|
-
from sonusai import create_file_handler
|
75
|
-
from sonusai import initial_log_messages
|
76
|
-
from sonusai import logger
|
77
|
-
from sonusai import update_console_handler
|
78
|
-
from sonusai.utils import create_ts_name
|
79
|
-
from sonusai.utils import import_and_check_keras_model
|
80
|
-
from sonusai.utils import pp_tqdm_imap
|
81
|
-
from sonusai.utils import seconds_to_hms
|
82
|
-
|
83
|
-
start_time = time.monotonic()
|
84
|
-
|
85
|
-
output_dir = create_ts_name('post_spenh_targetf')
|
86
|
-
makedirs(output_dir, exist_ok=True)
|
87
|
-
|
88
|
-
# Setup logging file
|
89
|
-
create_file_handler(join(output_dir, 'post_spenh_targetf.log'))
|
90
|
-
update_console_handler(verbose)
|
91
|
-
initial_log_messages('post_spenh_targetf')
|
92
|
-
|
93
|
-
hypermodel = import_and_check_keras_model(model_name=model_name, weights_name=weights_name)
|
94
|
-
|
95
|
-
fg = FeatureGenerator(feature_mode=hypermodel.feature,
|
96
|
-
num_classes=hypermodel.num_classes,
|
97
|
-
truth_mutex=hypermodel.truth_mutex)
|
98
|
-
|
99
|
-
MP_GLOBAL.N = fg.itransform_N
|
100
|
-
MP_GLOBAL.R = fg.itransform_R
|
101
|
-
MP_GLOBAL.bin_start = fg.bin_start
|
102
|
-
MP_GLOBAL.bin_end = fg.bin_end
|
103
|
-
MP_GLOBAL.ttype = fg.itransform_ttype
|
104
|
-
MP_GLOBAL.output_dir = output_dir
|
105
|
-
|
106
|
-
if not all(isfile(file) and splitext(file)[1] == '.h5' for file in input_name):
|
107
|
-
logger.exception(f'Do not know how to process input from {input_name}')
|
108
|
-
raise SystemExit(1)
|
109
|
-
|
110
|
-
logger.info('')
|
111
|
-
logger.info(f'Found {len(input_name):,} files to process')
|
112
|
-
|
113
|
-
progress = tqdm(total=len(input_name))
|
114
|
-
pp_tqdm_imap(_process, input_name, progress=progress)
|
115
|
-
progress.close()
|
116
|
-
|
117
|
-
logger.info(f'Wrote {len(input_name)} mixtures to {output_dir}')
|
118
|
-
logger.info('')
|
119
|
-
|
120
|
-
end_time = time.monotonic()
|
121
|
-
logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
|
122
|
-
logger.info('')
|
123
|
-
|
124
|
-
|
125
|
-
def _process(file: str) -> None:
|
126
|
-
"""Run extraction on predict data to generate estimation audio
|
127
|
-
"""
|
128
|
-
from os.path import basename
|
129
|
-
from os.path import join
|
130
|
-
from os.path import splitext
|
131
|
-
|
132
|
-
import h5py
|
133
|
-
import numpy as np
|
134
|
-
from sonusai import InverseTransform
|
135
|
-
|
136
|
-
from sonusai import SonusAIError
|
137
|
-
from sonusai.mixture import get_audio_from_transform
|
138
|
-
from sonusai.utils import float_to_int16
|
139
|
-
from sonusai.utils import unstack_complex
|
140
|
-
from sonusai.utils import write_audio
|
141
|
-
|
142
|
-
try:
|
143
|
-
with h5py.File(file, 'r') as f:
|
144
|
-
predict = unstack_complex(np.array(f['predict']))
|
145
|
-
except Exception as e:
|
146
|
-
raise SonusAIError(f'Error reading {file}: {e}')
|
147
|
-
|
148
|
-
output_name = join(MP_GLOBAL.output_dir, splitext(basename(file))[0] + '.wav')
|
149
|
-
audio, _ = get_audio_from_transform(data=predict,
|
150
|
-
transform=InverseTransform(N=MP_GLOBAL.N,
|
151
|
-
R=MP_GLOBAL.R,
|
152
|
-
bin_start=MP_GLOBAL.bin_start,
|
153
|
-
bin_end=MP_GLOBAL.bin_end,
|
154
|
-
ttype=MP_GLOBAL.ttype,
|
155
|
-
gain=np.float32(1)))
|
156
|
-
write_audio(name=output_name, audio=float_to_int16(audio))
|
157
|
-
|
158
|
-
|
159
|
-
if __name__ == '__main__':
|
160
|
-
main()
|
File without changes
|
File without changes
|