sonusai 0.18.6__py3-none-any.whl → 0.18.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +6 -1
- sonusai/genmetrics.py +4 -4
- sonusai/metrics/__init__.py +2 -1
- sonusai/metrics/calc_audio_stats.py +9 -1
- sonusai/metrics/calc_segsnr_f.py +84 -0
- sonusai/metrics/calc_speech.py +5 -5
- sonusai/mixture/__init__.py +3 -0
- sonusai/mixture/datatypes.py +65 -6
- sonusai/mixture/feature.py +4 -19
- sonusai/mixture/helpers.py +50 -39
- sonusai/mixture/mixdb.py +198 -59
- sonusai/mixture/sox_audio.py +125 -0
- sonusai/mixture/truth_functions/data.py +23 -22
- sonusai/mixture/truth_functions/energy.py +3 -1
- sonusai/mixture/truth_functions/sed.py +2 -1
- sonusai/mixture/truth_functions/target.py +3 -4
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/energy_f.py +3 -4
- {sonusai-0.18.6.dist-info → sonusai-0.18.8.dist-info}/METADATA +1 -1
- {sonusai-0.18.6.dist-info → sonusai-0.18.8.dist-info}/RECORD +23 -23
- sonusai/metrics/calc_snr_f.py +0 -34
- sonusai/post_spenh_targetf.py +0 -160
- {sonusai-0.18.6.dist-info → sonusai-0.18.8.dist-info}/WHEEL +0 -0
- {sonusai-0.18.6.dist-info → sonusai-0.18.8.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -17,6 +17,8 @@ from sonusai.mixture.datatypes import FeatureGeneratorConfig
|
|
17
17
|
from sonusai.mixture.datatypes import FeatureGeneratorInfo
|
18
18
|
from sonusai.mixture.datatypes import GeneralizedIDs
|
19
19
|
from sonusai.mixture.datatypes import ImpulseResponseFiles
|
20
|
+
from sonusai.mixture.datatypes import MetricDoc
|
21
|
+
from sonusai.mixture.datatypes import MetricDocs
|
20
22
|
from sonusai.mixture.datatypes import Mixture
|
21
23
|
from sonusai.mixture.datatypes import Mixtures
|
22
24
|
from sonusai.mixture.datatypes import NoiseFile
|
@@ -155,19 +157,69 @@ class MixtureDatabase:
|
|
155
157
|
return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
|
156
158
|
|
157
159
|
@cached_property
|
158
|
-
def supported_metrics(self) ->
|
159
|
-
metrics =
|
160
|
-
'
|
161
|
-
'
|
162
|
-
'
|
163
|
-
'
|
164
|
-
|
165
|
-
'
|
166
|
-
|
167
|
-
'
|
168
|
-
|
160
|
+
def supported_metrics(self) -> MetricDocs:
|
161
|
+
metrics = MetricDocs([
|
162
|
+
MetricDoc('Mixture Metrics', 'mxsnr', 'SNR specification in dB'),
|
163
|
+
MetricDoc('Mixture Metrics', 'mxssnr_avg', 'Segmental SNR average over all frames'),
|
164
|
+
MetricDoc('Mixture Metrics', 'mxssnr_std', 'Segmental SNR standard deviation over all frames'),
|
165
|
+
MetricDoc('Mixture Metrics', 'mxssnrdb_avg',
|
166
|
+
'Segmental SNR average of the dB frame values over all frames'),
|
167
|
+
MetricDoc('Mixture Metrics', 'mxssnrdb_std',
|
168
|
+
'Segmental SNR standard deviation of the dB frame values over all frames'),
|
169
|
+
MetricDoc('Mixture Metrics', 'mxssnrf_avg',
|
170
|
+
'Per-bin segmental SNR average over all frames (using feature transform)'),
|
171
|
+
MetricDoc('Mixture Metrics', 'mxssnrf_std',
|
172
|
+
'Per-bin segmental SNR standard deviation over all frames (using feature transform)'),
|
173
|
+
MetricDoc('Mixture Metrics', 'mxssnrdbf_avg',
|
174
|
+
'Per-bin segmental average of the dB frame values over all frames (using feature transform)'),
|
175
|
+
MetricDoc('Mixture Metrics', 'mxssnrdbf_std',
|
176
|
+
'Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)'),
|
177
|
+
MetricDoc('Mixture Metrics', 'mxpesq', 'PESQ of mixture versus true target[0]'),
|
178
|
+
MetricDoc('Mixture Metrics', 'mxwsdr', 'Weighted signal distorion ratio of mixture versus true target[0]'),
|
179
|
+
MetricDoc('Mixture Metrics', 'mxpd', 'Phase distance between mixture and true target[0]'),
|
180
|
+
MetricDoc('Mixture Metrics', 'mxstoi',
|
181
|
+
'Short term objective intelligibility of mixture versus true target[0]'),
|
182
|
+
MetricDoc('Mixture Metrics', 'mxcsig',
|
183
|
+
'Predicted rating of speech distortion of mixture versus true target[0]'),
|
184
|
+
MetricDoc('Mixture Metrics', 'mxcbak',
|
185
|
+
'Predicted rating of background distortion of mixture versus true target[0]'),
|
186
|
+
MetricDoc('Mixture Metrics', 'mxcovl',
|
187
|
+
'Predicted rating of overall quality of mixture versus true target[0]'),
|
188
|
+
MetricDoc('Mixture Metrics', 'ssnr', 'Segmental SNR'),
|
189
|
+
MetricDoc('Target Metrics', 'tdco', 'Target[0] DC offset'),
|
190
|
+
MetricDoc('Target Metrics', 'tmin', 'Target[0] min level'),
|
191
|
+
MetricDoc('Target Metrics', 'tmax', 'Target[0] max levl'),
|
192
|
+
MetricDoc('Target Metrics', 'tpkdb', 'Target[0] Pk lev dB'),
|
193
|
+
MetricDoc('Target Metrics', 'tlrms', 'Target[0] RMS lev dB'),
|
194
|
+
MetricDoc('Target Metrics', 'tpkr', 'Target[0] RMS Pk dB'),
|
195
|
+
MetricDoc('Target Metrics', 'ttr', 'Target[0] RMS Tr dB'),
|
196
|
+
MetricDoc('Target Metrics', 'tcr', 'Target[0] Crest factor'),
|
197
|
+
MetricDoc('Target Metrics', 'tfl', 'Target[0] Flat factor'),
|
198
|
+
MetricDoc('Target Metrics', 'tpkc', 'Target[0] Pk count'),
|
199
|
+
MetricDoc('Noise Metrics', 'ndco', 'Noise DC offset'),
|
200
|
+
MetricDoc('Noise Metrics', 'nmin', 'Noise min level'),
|
201
|
+
MetricDoc('Noise Metrics', 'nmax', 'Noise max levl'),
|
202
|
+
MetricDoc('Noise Metrics', 'npkdb', 'Noise Pk lev dB'),
|
203
|
+
MetricDoc('Noise Metrics', 'nlrms', 'Noise RMS lev dB'),
|
204
|
+
MetricDoc('Noise Metrics', 'npkr', 'Noise RMS Pk dB'),
|
205
|
+
MetricDoc('Noise Metrics', 'ntr', 'Noise RMS Tr dB'),
|
206
|
+
MetricDoc('Noise Metrics', 'ncr', 'Noise Crest factor'),
|
207
|
+
MetricDoc('Noise Metrics', 'nfl', 'Noise Flat factor'),
|
208
|
+
MetricDoc('Noise Metrics', 'npkc', 'Noise Pk count'),
|
209
|
+
MetricDoc('Truth Metrics', 'sedavg',
|
210
|
+
'(not implemented) Average SED activity over all frames [num_classes, 1]'),
|
211
|
+
MetricDoc('Truth Metrics', 'sedcnt',
|
212
|
+
'(not implemented) Count in number of frames that SED is active [num_classes, 1]'),
|
213
|
+
MetricDoc('Truth Metrics', 'sedtop3', '(not implemented) 3 most active by largest sedavg [3, 1]'),
|
214
|
+
MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
|
215
|
+
])
|
169
216
|
for name in self.asr_configs:
|
170
|
-
metrics.
|
217
|
+
metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
|
218
|
+
f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
219
|
+
metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
|
220
|
+
f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
|
221
|
+
metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
|
222
|
+
f'Word error rate using {name} ASR as defined in mixdb asr_configs parameter'))
|
171
223
|
|
172
224
|
return metrics
|
173
225
|
|
@@ -240,11 +292,15 @@ class MixtureDatabase:
|
|
240
292
|
def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
|
241
293
|
return self.total_samples(m_ids) // self.feature_step_samples
|
242
294
|
|
243
|
-
def mixture_transform_frames(self,
|
244
|
-
|
295
|
+
def mixture_transform_frames(self, m_id: int) -> int:
|
296
|
+
from .helpers import frames_from_samples
|
245
297
|
|
246
|
-
|
247
|
-
|
298
|
+
return frames_from_samples(self.mixture(m_id).samples, self.ft_config.R)
|
299
|
+
|
300
|
+
def mixture_feature_frames(self, m_id: int) -> int:
|
301
|
+
from .helpers import frames_from_samples
|
302
|
+
|
303
|
+
return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
|
248
304
|
|
249
305
|
def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
|
250
306
|
"""Resolve generalized mixture IDs to a list of integers
|
@@ -907,8 +963,8 @@ class MixtureDatabase:
|
|
907
963
|
truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
|
908
964
|
|
909
965
|
m = self.mixture(m_id)
|
910
|
-
transform_frames = self.mixture_transform_frames(
|
911
|
-
feature_frames = self.mixture_feature_frames(
|
966
|
+
transform_frames = self.mixture_transform_frames(m_id)
|
967
|
+
feature_frames = self.mixture_feature_frames(m_id)
|
912
968
|
|
913
969
|
if truth_t is None:
|
914
970
|
truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
|
@@ -1133,7 +1189,7 @@ class MixtureDatabase:
|
|
1133
1189
|
|
1134
1190
|
def mixture_metrics(self, m_id: int,
|
1135
1191
|
metrics: list[str],
|
1136
|
-
force: bool = False) -> list[float | int | Segsnr]:
|
1192
|
+
force: bool = False) -> list[float | int | str | Segsnr]:
|
1137
1193
|
"""Get metrics data for the given mixture ID
|
1138
1194
|
|
1139
1195
|
:param m_id: Zero-based mixture ID
|
@@ -1149,7 +1205,8 @@ class MixtureDatabase:
|
|
1149
1205
|
from sonusai import SonusAIError
|
1150
1206
|
from sonusai.metrics import calc_audio_stats
|
1151
1207
|
from sonusai.metrics import calc_phase_distance
|
1152
|
-
from sonusai.metrics import
|
1208
|
+
from sonusai.metrics import calc_segsnr_f
|
1209
|
+
from sonusai.metrics import calc_segsnr_f_bin
|
1153
1210
|
from sonusai.metrics import calc_speech
|
1154
1211
|
from sonusai.metrics import calc_wer
|
1155
1212
|
from sonusai.metrics import calc_wsdr
|
@@ -1158,7 +1215,7 @@ class MixtureDatabase:
|
|
1158
1215
|
from sonusai.mixture import SpeechMetrics
|
1159
1216
|
from sonusai.utils import calc_asr
|
1160
1217
|
|
1161
|
-
def create_target_audio() -> Callable:
|
1218
|
+
def create_target_audio() -> Callable[[], np.ndarray]:
|
1162
1219
|
state = None
|
1163
1220
|
|
1164
1221
|
def get() -> np.ndarray:
|
@@ -1171,7 +1228,20 @@ class MixtureDatabase:
|
|
1171
1228
|
|
1172
1229
|
target_audio = create_target_audio()
|
1173
1230
|
|
1174
|
-
def
|
1231
|
+
def create_target_f() -> Callable[[], np.ndarray]:
|
1232
|
+
state = None
|
1233
|
+
|
1234
|
+
def get() -> np.ndarray:
|
1235
|
+
nonlocal state
|
1236
|
+
if state is None:
|
1237
|
+
state = self.mixture_targets_f(m_id)[0]
|
1238
|
+
return state
|
1239
|
+
|
1240
|
+
return get
|
1241
|
+
|
1242
|
+
target_f = create_target_f()
|
1243
|
+
|
1244
|
+
def create_noise_audio() -> Callable[[], np.ndarray]:
|
1175
1245
|
state = None
|
1176
1246
|
|
1177
1247
|
def get() -> np.ndarray:
|
@@ -1184,7 +1254,20 @@ class MixtureDatabase:
|
|
1184
1254
|
|
1185
1255
|
noise_audio = create_noise_audio()
|
1186
1256
|
|
1187
|
-
def
|
1257
|
+
def create_noise_f() -> Callable[[], np.ndarray]:
|
1258
|
+
state = None
|
1259
|
+
|
1260
|
+
def get() -> np.ndarray:
|
1261
|
+
nonlocal state
|
1262
|
+
if state is None:
|
1263
|
+
state = self.mixture_noise_f(m_id)
|
1264
|
+
return state
|
1265
|
+
|
1266
|
+
return get
|
1267
|
+
|
1268
|
+
noise_f = create_noise_f()
|
1269
|
+
|
1270
|
+
def create_mixture_audio() -> Callable[[], np.ndarray]:
|
1188
1271
|
state = None
|
1189
1272
|
|
1190
1273
|
def get() -> np.ndarray:
|
@@ -1197,7 +1280,7 @@ class MixtureDatabase:
|
|
1197
1280
|
|
1198
1281
|
mixture_audio = create_mixture_audio()
|
1199
1282
|
|
1200
|
-
def create_segsnr_f() -> Callable:
|
1283
|
+
def create_segsnr_f() -> Callable[[], np.ndarray]:
|
1201
1284
|
state = None
|
1202
1285
|
|
1203
1286
|
def get() -> np.ndarray:
|
@@ -1210,7 +1293,7 @@ class MixtureDatabase:
|
|
1210
1293
|
|
1211
1294
|
segsnr_f = create_segsnr_f()
|
1212
1295
|
|
1213
|
-
def create_speech() -> Callable:
|
1296
|
+
def create_speech() -> Callable[[], SpeechMetrics]:
|
1214
1297
|
state = None
|
1215
1298
|
|
1216
1299
|
def get() -> SpeechMetrics:
|
@@ -1223,7 +1306,7 @@ class MixtureDatabase:
|
|
1223
1306
|
|
1224
1307
|
speech = create_speech()
|
1225
1308
|
|
1226
|
-
def create_target_stats() -> Callable:
|
1309
|
+
def create_target_stats() -> Callable[[], AudioStatsMetrics]:
|
1227
1310
|
state = None
|
1228
1311
|
|
1229
1312
|
def get() -> AudioStatsMetrics:
|
@@ -1236,7 +1319,7 @@ class MixtureDatabase:
|
|
1236
1319
|
|
1237
1320
|
target_stats = create_target_stats()
|
1238
1321
|
|
1239
|
-
def create_noise_stats() -> Callable:
|
1322
|
+
def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
|
1240
1323
|
state = None
|
1241
1324
|
|
1242
1325
|
def get() -> AudioStatsMetrics:
|
@@ -1249,7 +1332,56 @@ class MixtureDatabase:
|
|
1249
1332
|
|
1250
1333
|
noise_stats = create_noise_stats()
|
1251
1334
|
|
1252
|
-
def
|
1335
|
+
def create_asr_config() -> Callable[[str], dict]:
|
1336
|
+
state: dict[str, dict] = {}
|
1337
|
+
|
1338
|
+
def get(asr_name) -> dict:
|
1339
|
+
nonlocal state
|
1340
|
+
if asr_name not in state:
|
1341
|
+
state[asr_name] = self.asr_configs.get(asr_name, None)
|
1342
|
+
if state[asr_name] is None:
|
1343
|
+
raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
|
1344
|
+
return state[asr_name]
|
1345
|
+
|
1346
|
+
return get
|
1347
|
+
|
1348
|
+
asr_config = create_asr_config()
|
1349
|
+
|
1350
|
+
def create_target_asr() -> Callable[[str], str]:
|
1351
|
+
state: dict[str, str] = {}
|
1352
|
+
|
1353
|
+
def get(asr_name) -> str:
|
1354
|
+
nonlocal state
|
1355
|
+
if asr_name not in state:
|
1356
|
+
state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
|
1357
|
+
return state[asr_name]
|
1358
|
+
|
1359
|
+
return get
|
1360
|
+
|
1361
|
+
target_asr = create_target_asr()
|
1362
|
+
|
1363
|
+
def create_mixture_asr() -> Callable[[str], str]:
|
1364
|
+
state: dict[str, str] = {}
|
1365
|
+
|
1366
|
+
def get(asr_name) -> str:
|
1367
|
+
nonlocal state
|
1368
|
+
if asr_name not in state:
|
1369
|
+
state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
|
1370
|
+
return state[asr_name]
|
1371
|
+
|
1372
|
+
return get
|
1373
|
+
|
1374
|
+
mixture_asr = create_mixture_asr()
|
1375
|
+
|
1376
|
+
def get_asr_name(m: str) -> str:
|
1377
|
+
parts = m.split('.')
|
1378
|
+
if len(parts) != 2:
|
1379
|
+
raise SonusAIError(
|
1380
|
+
f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1381
|
+
asr_name = parts[1]
|
1382
|
+
return asr_name
|
1383
|
+
|
1384
|
+
def calc(m: str) -> float | int | str | Segsnr:
|
1253
1385
|
if m == 'mxsnr':
|
1254
1386
|
return self.mixture(m_id).snr
|
1255
1387
|
|
@@ -1261,42 +1393,44 @@ class MixtureDatabase:
|
|
1261
1393
|
|
1262
1394
|
# Otherwise, generate data as needed
|
1263
1395
|
if m.startswith('mxwer'):
|
1264
|
-
|
1265
|
-
if len(parts) != 2:
|
1266
|
-
raise SonusAIError(
|
1267
|
-
f"Unrecognized 'mxwer' format: '{m}'; must be of the form: 'mxwer.<name>'")
|
1268
|
-
asr_name = parts[1]
|
1269
|
-
asr_config = self.asr_configs.get(asr_name, None)
|
1270
|
-
if asr_config is None:
|
1271
|
-
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1396
|
+
asr_name = get_asr_name(m)
|
1272
1397
|
|
1273
1398
|
if self.mixture(m_id).snr < -96:
|
1274
1399
|
# noise only, ignore/reset target asr
|
1275
1400
|
return float('nan')
|
1276
1401
|
|
1277
|
-
|
1278
|
-
|
1279
|
-
if target_asr is None:
|
1280
|
-
target_asr = calc_asr(target_audio(), **asr_config).text
|
1281
|
-
|
1282
|
-
if target_asr:
|
1283
|
-
mixture_asr = calc_asr(mixture_audio(), **asr_config).text
|
1284
|
-
return calc_wer(mixture_asr, target_asr).wer * 100
|
1402
|
+
if target_asr(asr_name):
|
1403
|
+
return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
|
1285
1404
|
|
1286
1405
|
# TODO: should this be NaN like above?
|
1287
1406
|
return float(0)
|
1288
1407
|
|
1289
|
-
if m
|
1290
|
-
return
|
1408
|
+
if m.startswith('mxasr'):
|
1409
|
+
return mixture_asr(get_asr_name(m))
|
1410
|
+
|
1411
|
+
if m == 'mxssnr_avg':
|
1412
|
+
return calc_segsnr_f(segsnr_f()).avg
|
1291
1413
|
|
1292
|
-
if m == '
|
1293
|
-
return
|
1414
|
+
if m == 'mxssnr_std':
|
1415
|
+
return calc_segsnr_f(segsnr_f()).std
|
1294
1416
|
|
1295
|
-
if m == '
|
1296
|
-
return
|
1417
|
+
if m == 'mxssnrdb_avg':
|
1418
|
+
return calc_segsnr_f(segsnr_f()).db_avg
|
1297
1419
|
|
1298
|
-
if m == '
|
1299
|
-
return
|
1420
|
+
if m == 'mxssnrdb_std':
|
1421
|
+
return calc_segsnr_f(segsnr_f()).db_std
|
1422
|
+
|
1423
|
+
if m == 'mxssnrf_avg':
|
1424
|
+
return calc_segsnr_f_bin(target_f(), noise_f()).avg
|
1425
|
+
|
1426
|
+
if m == 'mxssnrf_std':
|
1427
|
+
return calc_segsnr_f_bin(target_f(), noise_f()).std
|
1428
|
+
|
1429
|
+
if m == 'mxssnrdbf_avg':
|
1430
|
+
return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
|
1431
|
+
|
1432
|
+
if m == 'mxssnrdbf_std':
|
1433
|
+
return calc_segsnr_f_bin(target_f(), noise_f()).db_std
|
1300
1434
|
|
1301
1435
|
if m == 'mxpesq':
|
1302
1436
|
if self.mixture(m_id).snr < -96:
|
@@ -1306,17 +1440,17 @@ class MixtureDatabase:
|
|
1306
1440
|
if m == 'mxcsig':
|
1307
1441
|
if self.mixture(m_id).snr < -96:
|
1308
1442
|
return 0
|
1309
|
-
return speech().
|
1443
|
+
return speech().csig
|
1310
1444
|
|
1311
1445
|
if m == 'mxcbak':
|
1312
1446
|
if self.mixture(m_id).snr < -96:
|
1313
1447
|
return 0
|
1314
|
-
return speech().
|
1448
|
+
return speech().cbak
|
1315
1449
|
|
1316
1450
|
if m == 'mxcovl':
|
1317
1451
|
if self.mixture(m_id).snr < -96:
|
1318
1452
|
return 0
|
1319
|
-
return speech().
|
1453
|
+
return speech().covl
|
1320
1454
|
|
1321
1455
|
if m == 'mxwsdr':
|
1322
1456
|
mixture = mixture_audio()[:, np.newaxis]
|
@@ -1328,8 +1462,7 @@ class MixtureDatabase:
|
|
1328
1462
|
|
1329
1463
|
if m == 'mxpd':
|
1330
1464
|
mixture_f = self.mixture_mixture_f(m_id)
|
1331
|
-
|
1332
|
-
return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]
|
1465
|
+
return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
|
1333
1466
|
|
1334
1467
|
if m == 'mxstoi':
|
1335
1468
|
return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
|
@@ -1364,6 +1497,9 @@ class MixtureDatabase:
|
|
1364
1497
|
if m == 'tpkc':
|
1365
1498
|
return target_stats().pkc
|
1366
1499
|
|
1500
|
+
if m.startswith('tasr'):
|
1501
|
+
return target_asr(get_asr_name(m))
|
1502
|
+
|
1367
1503
|
if m == 'ndco':
|
1368
1504
|
return noise_stats().dco
|
1369
1505
|
|
@@ -1400,15 +1536,18 @@ class MixtureDatabase:
|
|
1400
1536
|
if m == 'sedcnt':
|
1401
1537
|
return 0
|
1402
1538
|
|
1539
|
+
if m == 'sedtop3':
|
1540
|
+
return np.zeros(3, dtype=np.float32)
|
1541
|
+
|
1403
1542
|
if m == 'sedtopn':
|
1404
1543
|
return 0
|
1405
1544
|
|
1406
1545
|
if m == 'ssnr':
|
1407
|
-
return
|
1546
|
+
return segsnr_f()
|
1408
1547
|
|
1409
1548
|
raise SonusAIError(f"Unrecognized metric: '{m}'")
|
1410
1549
|
|
1411
|
-
result: list[float | int | Segsnr] = []
|
1550
|
+
result: list[float | int | str | Segsnr] = []
|
1412
1551
|
for metric in metrics:
|
1413
1552
|
result.append(calc(metric))
|
1414
1553
|
|
sonusai/mixture/sox_audio.py
CHANGED
@@ -210,6 +210,131 @@ class Transformer(SoxTransformer):
|
|
210
210
|
|
211
211
|
return self
|
212
212
|
|
213
|
+
def build(self,
|
214
|
+
input_filepath: Optional[str | Path] = None,
|
215
|
+
output_filepath: Optional[str | Path] = None,
|
216
|
+
input_array: Optional[np.ndarray] = None,
|
217
|
+
sample_rate_in: Optional[float] = None,
|
218
|
+
extra_args: Optional[list[str]] = None,
|
219
|
+
return_output: bool = False) -> tuple[bool, Optional[str], Optional[str]]:
|
220
|
+
"""Given an input file or array, creates an output_file on disk by
|
221
|
+
executing the current set of commands. This function returns True on
|
222
|
+
success. If return_output is True, this function returns a triple of
|
223
|
+
(status, out, err), giving the success state, along with stdout and
|
224
|
+
stderr returned by sox.
|
225
|
+
|
226
|
+
Parameters
|
227
|
+
----------
|
228
|
+
input_filepath : str or None
|
229
|
+
Either path to input audio file or None for array input.
|
230
|
+
output_filepath : str
|
231
|
+
Path to desired output file. If a file already exists at
|
232
|
+
the given path, the file will be overwritten.
|
233
|
+
If '-n', no file is created.
|
234
|
+
input_array : np.ndarray or None
|
235
|
+
An np.ndarray of an waveform with shape (n_samples, n_channels).
|
236
|
+
sample_rate_in must also be provided.
|
237
|
+
If None, input_filepath must be specified.
|
238
|
+
sample_rate_in : int
|
239
|
+
Sample rate of input_array.
|
240
|
+
This argument is ignored if input_array is None.
|
241
|
+
extra_args : list or None, default=None
|
242
|
+
If a list is given, these additional arguments are passed to SoX
|
243
|
+
at the end of the list of effects.
|
244
|
+
Don't use this argument unless you know exactly what you're doing!
|
245
|
+
return_output : bool, default=False
|
246
|
+
If True, returns the status and information sent to stderr and
|
247
|
+
stdout as a tuple (status, stdout, stderr).
|
248
|
+
If output_filepath is None, return_output=True by default.
|
249
|
+
If False, returns True on success.
|
250
|
+
|
251
|
+
Returns
|
252
|
+
-------
|
253
|
+
status : bool
|
254
|
+
True on success.
|
255
|
+
out : str (optional)
|
256
|
+
This is not returned unless return_output is True.
|
257
|
+
When returned, captures the stdout produced by sox.
|
258
|
+
err : str (optional)
|
259
|
+
This is not returned unless return_output is True.
|
260
|
+
When returned, captures the stderr produced by sox.
|
261
|
+
|
262
|
+
Examples
|
263
|
+
--------
|
264
|
+
> import numpy as np
|
265
|
+
> import sox
|
266
|
+
> tfm = sox.Transformer()
|
267
|
+
> sample_rate = 44100
|
268
|
+
> y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)
|
269
|
+
|
270
|
+
file in, file out - basic usage
|
271
|
+
|
272
|
+
> status = tfm.build('path/to/input.wav', 'path/to/output.mp3')
|
273
|
+
|
274
|
+
file in, file out - equivalent usage
|
275
|
+
|
276
|
+
> status = tfm.build(
|
277
|
+
input_filepath='path/to/input.wav',
|
278
|
+
output_filepath='path/to/output.mp3'
|
279
|
+
)
|
280
|
+
|
281
|
+
array in, file out
|
282
|
+
|
283
|
+
> status = tfm.build(
|
284
|
+
input_array=y, sample_rate_in=sample_rate,
|
285
|
+
output_filepath='path/to/output.mp3'
|
286
|
+
)
|
287
|
+
|
288
|
+
"""
|
289
|
+
from sox import file_info
|
290
|
+
from sox.core import SoxError
|
291
|
+
from sox.core import sox
|
292
|
+
from sox.log import logger
|
293
|
+
|
294
|
+
input_format, input_filepath = self._parse_inputs(
|
295
|
+
input_filepath, input_array, sample_rate_in
|
296
|
+
)
|
297
|
+
|
298
|
+
if output_filepath is None:
|
299
|
+
raise ValueError("output_filepath is not specified!")
|
300
|
+
|
301
|
+
# set output parameters
|
302
|
+
if input_filepath == output_filepath:
|
303
|
+
raise ValueError(
|
304
|
+
"input_filepath must be different from output_filepath."
|
305
|
+
)
|
306
|
+
file_info.validate_output_file(output_filepath)
|
307
|
+
|
308
|
+
args = []
|
309
|
+
args.extend(self.globals)
|
310
|
+
args.extend(self._input_format_args(input_format))
|
311
|
+
args.append(input_filepath)
|
312
|
+
args.extend(self._output_format_args(self.output_format))
|
313
|
+
args.append(output_filepath)
|
314
|
+
args.extend(self.effects)
|
315
|
+
|
316
|
+
if extra_args is not None:
|
317
|
+
if not isinstance(extra_args, list):
|
318
|
+
raise ValueError("extra_args must be a list.")
|
319
|
+
args.extend(extra_args)
|
320
|
+
|
321
|
+
status, out, err = sox(args, input_array, True)
|
322
|
+
if status != 0:
|
323
|
+
raise SoxError(
|
324
|
+
f"Stdout: {out}\nStderr: {err}"
|
325
|
+
)
|
326
|
+
|
327
|
+
logger.info(
|
328
|
+
"Created %s with effects: %s",
|
329
|
+
output_filepath,
|
330
|
+
" ".join(self.effects_log)
|
331
|
+
)
|
332
|
+
|
333
|
+
if return_output:
|
334
|
+
return status, out, err
|
335
|
+
|
336
|
+
return True, None, None
|
337
|
+
|
213
338
|
def build_array(self,
|
214
339
|
input_filepath: Optional[str | Path] = None,
|
215
340
|
input_array: Optional[np.ndarray] = None,
|
@@ -3,13 +3,14 @@ from sonusai.mixture.datatypes import TruthFunctionConfig
|
|
3
3
|
|
4
4
|
|
5
5
|
class Data:
|
6
|
-
def __init__(self,
|
6
|
+
def __init__(self,
|
7
|
+
target_audio: AudioT,
|
7
8
|
noise_audio: AudioT,
|
8
9
|
mixture_audio: AudioT,
|
9
10
|
config: TruthFunctionConfig) -> None:
|
10
11
|
import numpy as np
|
11
|
-
from
|
12
|
-
from
|
12
|
+
from sonusai import ForwardTransform
|
13
|
+
from sonusai import InverseTransform
|
13
14
|
from pyaaware import FeatureGenerator
|
14
15
|
|
15
16
|
from sonusai import SonusAIError
|
@@ -33,25 +34,25 @@ class Data:
|
|
33
34
|
|
34
35
|
self.offsets = range(0, len(target_audio), self.frame_size)
|
35
36
|
self.zero_based_indices = [x - 1 for x in config.index]
|
36
|
-
self.target_fft =
|
37
|
-
|
38
|
-
bin_start=fg.bin_start,
|
39
|
-
bin_end=fg.bin_end,
|
40
|
-
ttype=fg.ftransform_ttype)
|
41
|
-
self.noise_fft = AawareForwardTransform(N=fg.ftransform_N,
|
42
|
-
R=fg.ftransform_R,
|
43
|
-
bin_start=fg.bin_start,
|
44
|
-
bin_end=fg.bin_end,
|
45
|
-
ttype=fg.ftransform_ttype)
|
46
|
-
self.mixture_fft = AawareForwardTransform(N=fg.ftransform_N,
|
47
|
-
R=fg.ftransform_R,
|
48
|
-
bin_start=fg.bin_start,
|
49
|
-
bin_end=fg.bin_end,
|
50
|
-
ttype=fg.ftransform_ttype)
|
51
|
-
self.swin = AawareInverseTransform(N=fg.itransform_N,
|
52
|
-
R=fg.itransform_R,
|
37
|
+
self.target_fft = ForwardTransform(N=fg.ftransform_N,
|
38
|
+
R=fg.ftransform_R,
|
53
39
|
bin_start=fg.bin_start,
|
54
40
|
bin_end=fg.bin_end,
|
55
|
-
ttype=fg.
|
56
|
-
|
41
|
+
ttype=fg.ftransform_ttype)
|
42
|
+
self.noise_fft = ForwardTransform(N=fg.ftransform_N,
|
43
|
+
R=fg.ftransform_R,
|
44
|
+
bin_start=fg.bin_start,
|
45
|
+
bin_end=fg.bin_end,
|
46
|
+
ttype=fg.ftransform_ttype)
|
47
|
+
self.mixture_fft = ForwardTransform(N=fg.ftransform_N,
|
48
|
+
R=fg.ftransform_R,
|
49
|
+
bin_start=fg.bin_start,
|
50
|
+
bin_end=fg.bin_end,
|
51
|
+
ttype=fg.ftransform_ttype)
|
52
|
+
self.swin = InverseTransform(N=fg.itransform_N,
|
53
|
+
R=fg.itransform_R,
|
54
|
+
bin_start=fg.bin_start,
|
55
|
+
bin_end=fg.bin_end,
|
56
|
+
ttype=fg.itransform_ttype,
|
57
|
+
gain=np.float32(1)).W
|
57
58
|
self.truth = np.zeros((len(target_audio), config.num_classes), dtype=np.float32)
|
@@ -132,9 +132,11 @@ def energy_t(data: Data) -> Truth:
|
|
132
132
|
will reflect the total energy over all bins regardless of the feature
|
133
133
|
transform config.
|
134
134
|
"""
|
135
|
+
import torch
|
136
|
+
|
135
137
|
from sonusai import SonusAIError
|
136
138
|
|
137
|
-
|
139
|
+
target_energy = data.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
|
138
140
|
if len(target_energy) != len(data.offsets):
|
139
141
|
raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
|
140
142
|
f' is not number of frames in truth, {len(data.offsets)}')
|
@@ -21,6 +21,7 @@ should be set to the number of sounds/classes to be detected + 1 for
|
|
21
21
|
the other class.
|
22
22
|
"""
|
23
23
|
import numpy as np
|
24
|
+
import torch
|
24
25
|
from pyaaware import SED
|
25
26
|
|
26
27
|
from sonusai import SonusAIError
|
@@ -48,7 +49,7 @@ the other class.
|
|
48
49
|
mutex=data.config.mutex)
|
49
50
|
|
50
51
|
target_audio = data.target_audio / data.config.target_gain
|
51
|
-
|
52
|
+
energy_t = data.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
|
52
53
|
if len(energy_t) != len(data.offsets):
|
53
54
|
raise SonusAIError(f'Number of frames in energy_t, {len(energy_t)},'
|
54
55
|
f' is not number of frames in truth, {len(data.offsets)}')
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from
|
1
|
+
from sonusai import ForwardTransform
|
2
2
|
|
3
3
|
from sonusai.mixture.datatypes import AudioF
|
4
4
|
from sonusai.mixture.datatypes import AudioT
|
@@ -98,7 +98,6 @@ Output shape: [:, 2 * bins] (stacked real, imag)
|
|
98
98
|
for idx, offset in enumerate(data.offsets):
|
99
99
|
target_freq, _ = data.target_fft.execute(
|
100
100
|
np.multiply(data.target_audio[offset:offset + data.frame_size], data.swin))
|
101
|
-
target_freq = target_freq.transpose()
|
102
101
|
|
103
102
|
indices = slice(offset, offset + data.frame_size)
|
104
103
|
for index in data.zero_based_indices:
|
@@ -112,10 +111,10 @@ Output shape: [:, 2 * bins] (stacked real, imag)
|
|
112
111
|
|
113
112
|
|
114
113
|
def _execute_fft(audio: AudioT, transform: ForwardTransform, expected_frames: int) -> AudioF:
|
114
|
+
import torch
|
115
115
|
from sonusai import SonusAIError
|
116
116
|
|
117
|
-
freq
|
118
|
-
freq = freq.transpose()
|
117
|
+
freq = transform.execute_all(torch.from_numpy(audio))[0].numpy()
|
119
118
|
if len(freq) != expected_frames:
|
120
119
|
raise SonusAIError(f'Number of frames, {len(freq)}, is not number of frames expected, {expected_frames}')
|
121
120
|
return freq
|