sonusai 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +32 -362
- sonusai/data/genmixdb.yml +2 -0
- sonusai/doc/doc.py +45 -4
- sonusai/genmetrics.py +137 -109
- sonusai/lsdb.py +2 -2
- sonusai/metrics/__init__.py +4 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_pesq.py +12 -8
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_snr_f.py +34 -0
- sonusai/metrics/calc_speech.py +312 -0
- sonusai/metrics/calc_wer.py +2 -3
- sonusai/metrics/calc_wsdr.py +0 -59
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +6 -5
- sonusai/mixture/config.py +13 -0
- sonusai/mixture/constants.py +1 -0
- sonusai/mixture/datatypes.py +33 -0
- sonusai/mixture/generation.py +6 -2
- sonusai/mixture/mixdb.py +261 -122
- sonusai/mixture/soundfile_audio.py +8 -6
- sonusai/mixture/sox_audio.py +16 -13
- sonusai/mixture/torchaudio_audio.py +6 -4
- sonusai/mixture/truth_functions/energy.py +40 -28
- sonusai/mixture/truth_functions/target.py +0 -1
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr.py +26 -39
- sonusai/utils/asr_functions/aaware_whisper.py +3 -3
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/METADATA +1 -1
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/RECORD +34 -31
- sonusai/mixture/mapped_snr_f.py +0 -100
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/WHEEL +0 -0
- {sonusai-0.18.2.dist-info → sonusai-0.18.4.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -6,6 +6,7 @@ from sqlite3 import Cursor
 from typing import Any
 from typing import Optional
 
+from sonusai.mixture.datatypes import ASRConfigs
 from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosF
@@ -88,6 +89,7 @@ class MixtureDatabase:
         from .datatypes import MixtureDatabaseConfig
 
         config = MixtureDatabaseConfig(
+            asr_configs=self.asr_configs,
             class_balancing=self.class_balancing,
             class_labels=self.class_labels,
             class_weights_threshold=self.class_weights_thresholds,
@@ -145,6 +147,30 @@ class MixtureDatabase:
         with self.db() as c:
             return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
 
+    @cached_property
+    def asr_configs(self) -> ASRConfigs:
+        import json
+
+        with self.db() as c:
+            return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
+
+    @cached_property
+    def supported_metrics(self) -> set[str]:
+        metrics = {
+            'mxssnravg', 'mxssnrvar', 'mxssnrdavg', 'mxssnrdstd',
+            'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', 'mxwsdr',
+            'mxpd',
+            'mxstoi',
+            'tdco', 'tmin', 'tmax', 'tpkdb', 'tlrms', 'tpkr', 'ttr', 'tcr', 'tfl', 'tpkc',
+            'ndco', 'nmin', 'nmax', 'npkdb', 'nlrms', 'npkr', 'ntr', 'ncr', 'nfl', 'npkc',
+            'sedavg', 'sedcnt', 'sedtopn',
+            'ssnr',
+        }
+        for name in self.asr_configs:
+            metrics.add(f'mxwer.{name}')
+
+        return metrics
+
     @cached_property
     def class_balancing(self) -> bool:
         with self.db() as c:
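For reference, a minimal sketch of how the two new properties might be queried, assuming a MixtureDatabase is constructed from an existing database location (the './mixdb' path is illustrative, not part of this diff):

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase('./mixdb')      # hypothetical database location
    print(mixdb.asr_configs)                # ASR configurations stored in the 'top' table
    print(sorted(mixdb.supported_metrics))  # base metric names plus one 'mxwer.<name>' per ASR config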
@@ -1108,173 +1134,286 @@ class MixtureDatabase:
 
         return mixture_all_speech_metadata(self, self.mixture(m_id))
 
-    def
+    def mixture_metrics(self, m_id: int,
+                        metrics: list[str],
+                        force: bool = False) -> list[float | int | Segsnr]:
+        """Get metrics data for the given mixture ID
 
         :param m_id: Zero-based mixture ID
-        :param
+        :param metrics: List of metrics to get
         :param force: Force computing data from original sources regardless of whether cached data exists
-        :return:
+        :return: List of metric data
         """
+        from typing import Callable
+
+        import numpy as np
+        from pystoi import stoi
+
         from sonusai import SonusAIError
+        from sonusai.metrics import calc_audio_stats
+        from sonusai.metrics import calc_phase_distance
+        from sonusai.metrics import calc_snr_f
+        from sonusai.metrics import calc_speech
+        from sonusai.metrics import calc_wer
+        from sonusai.metrics import calc_wsdr
+        from sonusai.mixture import SAMPLE_RATE
+        from sonusai.mixture import AudioStatsMetrics
+        from sonusai.mixture import SpeechMetrics
+        from sonusai.utils import calc_asr
 
-            'MXSSNRAVG',
-            'MXSSNRSTD',
-            'MXSSNRDAVG',
-            'MXSSNRDSTD',
-            'MXPESQ',
-            'MXWSDR',
-            'MXPD',
-            'MXSTOI',
-            'MXCSIG',
-            'MXCBAK',
-            'MXCOVL',
-            'TDCO',
-            'TMIN',
-            'TMAX',
-            'TPKDB',
-            'TLRMS',
-            'TPKR',
-            'TTR',
-            'TCR',
-            'TFL',
-            'TPKC',
-            'NDCO',
-            'NMIN',
-            'NMAX',
-            'NPKDB',
-            'NLRMS',
-            'NPKR',
-            'NTR',
-            'NCR',
-            'NFL',
-            'NPKC',
-            'SEDAVG',
-            'SEDCNT',
-            'SEDTOPN',
-        )
+        def create_target_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_target(m_id)
+                return state
 
-        result = self.read_mixture_data(m_id, metric)
-        if result is not None:
-            return result
+            return get
 
-        if mixture is None:
-            raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
+        target_audio = create_target_audio()
 
+        def create_noise_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_noise(m_id)
+                return state
 
-        return None
+            return get
 
-        return None
+        noise_audio = create_noise_audio()
 
+        def create_mixture_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_mixture(m_id)
+                return state
 
-        return None
+            return get
 
-        return None
+        mixture_audio = create_mixture_audio()
 
+        def create_segsnr_f() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_segsnr(m_id)
+                return state
 
-        return None
+            return get
 
-        return None
+        segsnr_f = create_segsnr_f()
 
+        def create_speech() -> Callable:
+            state = None
 
+            def get() -> SpeechMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_speech(hypothesis=mixture_audio(), reference=target_audio())
+                return state
 
-        return None
+            return get
 
-        return None
+        speech = create_speech()
 
+        def create_target_stats() -> Callable:
+            state = None
 
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state
 
-        return None
+            return get
 
-        return None
+        target_stats = create_target_stats()
 
+        def create_noise_stats() -> Callable:
+            state = None
 
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state
 
-        return None
+            return get
 
-        return None
+        noise_stats = create_noise_stats()
 
+        def calc(m: str) -> float | int | Segsnr:
+            if m == 'mxsnr':
+                return self.mixture(m_id).snr
 
+            # Get cached data first, if exists
+            if not force:
+                value = self.read_mixture_data(m_id, m)
+                if value is not None:
+                    return value
 
+            # Otherwise, generate data as needed
+            if m.startswith('mxwer'):
+                parts = m.split('.')
+                if len(parts) != 3:
+                    raise SonusAIError(
+                        f"Unrecognized 'mwwer' format: '{m}'; must be of the form: 'mxwer.<engine>.<model>'")
+                asr_engine = parts[1]
+                asr_model = parts[2]
 
+                if asr_engine == 'none' or self.mixture(m_id).snr < -96:
+                    # noise only, ignore/reset target asr
+                    return float('nan')
 
+                # ignore mixup
+                target_asr = self.mixture_speech_metadata(m_id, 'text')[0]
+                if target_asr is None:
+                    target_asr = calc_asr(target_audio(), engine=asr_engine, whisper_model_name=asr_model).text
 
+                if target_asr:
+                    mixture_asr = calc_asr(mixture_audio(), engine=asr_engine, whisper_model_name=asr_model).text
+                    return calc_wer(mixture_asr, target_asr).wer * 100
 
+                # TODO: should this be NaN like above?
+                return float(0)
 
+            if m == 'mxssnravg':
+                return calc_snr_f(segsnr_f()).mean
 
+            if m == 'mxssnrvar':
+                return calc_snr_f(segsnr_f()).var
 
+            if m == 'mxssnrdavg':
+                return calc_snr_f(segsnr_f()).db_mean
 
+            if m == 'mxssnrdstd':
+                return calc_snr_f(segsnr_f()).db_std
 
+            if m == 'mxpesq':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().pesq
+
+            if m == 'mxcsig':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_sig
+
+            if m == 'mxcbak':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_bak
+
+            if m == 'mxcovl':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_ovl
+
+            if m == 'mxwsdr':
+                mixture = mixture_audio()[:, np.newaxis]
+                target = target_audio()[:, np.newaxis]
+                noise = noise_audio()[:, np.newaxis]
+                return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
+                                 reference=np.concatenate((target, noise), axis=1),
+                                 with_log=True)[0]
+
+            if m == 'mxpd':
+                mixture_f = self.mixture_mixture_f(m_id)
+                target_f = self.mixture_target_f(m_id)
+                return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]
+
+            if m == 'mxstoi':
+                return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
+
+            if m == 'tdco':
+                return target_stats().dco
+
+            if m == 'tmin':
+                return target_stats().min
+
+            if m == 'tmax':
+                return target_stats().max
+
+            if m == 'tpkdb':
+                return target_stats().pkdb
+
+            if m == 'tlrms':
+                return target_stats().lrms
+
+            if m == 'tpkr':
+                return target_stats().pkr
+
+            if m == 'ttr':
+                return target_stats().tr
+
+            if m == 'tcr':
+                return target_stats().cr
+
+            if m == 'tfl':
+                return target_stats().fl
+
+            if m == 'tpkc':
+                return target_stats().pkc
+
+            if m == 'ndco':
+                return noise_stats().dco
+
+            if m == 'nmin':
+                return noise_stats().min
+
+            if m == 'nmax':
+                return noise_stats().max
+
+            if m == 'npkdb':
+                return noise_stats().pkdb
+
+            if m == 'nlrms':
+                return noise_stats().lrms
+
+            if m == 'npkr':
+                return noise_stats().pkr
+
+            if m == 'ntr':
+                return noise_stats().tr
+
+            if m == 'ncr':
+                return noise_stats().cr
+
+            if m == 'nfl':
+                return noise_stats().fl
+
+            if m == 'npkc':
+                return noise_stats().pkc
+
+            if m == 'sedavg':
+                return 0
+
+            if m == 'sedcnt':
+                return 0
+
+            if m == 'sedtopn':
+                return 0
+
+            if m == 'ssnr':
+                return self.mixture_segsnr(m_id)
+
+            raise SonusAIError(f"Unrecognized metric: '{m}'")
+
+        result: list[float | int | Segsnr] = []
+        for metric in metrics:
+            result.append(calc(metric))
+
+        return result
 
 
     @lru_cache
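A brief usage sketch of the reworked method, assuming an existing mixture database; the location and mixture ID are illustrative, and metric names follow the lower-case scheme listed in supported_metrics:

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase('./mixdb')  # hypothetical database location

    # Several metrics for mixture 0 in one call; results are returned in request order.
    names = ['mxsnr', 'mxssnravg', 'mxpesq', 'mxstoi']
    values = mixdb.mixture_metrics(m_id=0, metrics=names)
    for name, value in zip(names, values):
        print(f'{name}: {value}')

    # force=True recomputes from the original sources instead of using cached data.
    fresh = mixdb.mixture_metrics(m_id=0, metrics=['mxssnravg'], force=True)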
sonusai/mixture/soundfile_audio.py
CHANGED
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def _raw_read(name: str) -> tuple[AudioT, int]:
+def _raw_read(name: str | Path) -> tuple[AudioT, int]:
     import numpy as np
     import soundfile
     from pydub import AudioSegment
@@ -34,7 +36,7 @@ def _raw_read(name: str) -> tuple[AudioT, int]:
     return np.squeeze(raw[:, 0]), sample_rate
 
 
-def get_sample_rate(name: str) -> int:
+def get_sample_rate(name: str | Path) -> int:
     """Get sample rate from audio file using soundfile
 
     :param name: File name
@@ -63,7 +65,7 @@ def get_sample_rate(name: str) -> int:
         raise SonusAIError(f'Error reading {name}: {e}')
 
 
-def read_ir(name: str) -> ImpulseResponseData:
+def read_ir(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using soundfile
 
     :param name: File name
@@ -79,10 +81,10 @@ def read_ir(name: str) -> ImpulseResponseData:
     out = out[offset:]
     out = out / np.linalg.norm(out)
 
-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=out)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=out)
 
 
-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using soundfile
 
     :param name: File name
@@ -101,7 +103,7 @@ def read_audio(name: str) -> AudioT:
     return out
 
 
-def get_num_samples(name: str) -> int:
+def get_num_samples(name: str | Path) -> int:
     """Get the number of samples resampled to the SonusAI sample rate in the given file
 
     :param name: File name
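Since these readers now accept pathlib.Path as well as str, a minimal usage sketch (the file path is illustrative):

    from pathlib import Path

    from sonusai.mixture.soundfile_audio import get_sample_rate, read_audio

    wav = Path('speech') / 'example.wav'  # hypothetical input file
    audio = read_audio(wav)               # a plain str still works
    rate = get_sample_rate(wav)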
sonusai/mixture/sox_audio.py
CHANGED
@@ -1,16 +1,19 @@
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
 from sox import Transformer as SoxTransformer
 
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def read_impulse_response(name: str) -> ImpulseResponseData:
+def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using SoX
 
     :param name: File name
     :return: ImpulseResponseData object
     """
-    import numpy as np
     from scipy.io import wavfile
 
     from sonusai import SonusAIError
@@ -33,10 +36,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
     data = data[offset:]
     data = data / np.linalg.norm(data)
 
-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
 
 
-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using SoX
 
     :param name: File name
@@ -44,7 +47,6 @@ def read_audio(name: str) -> AudioT:
     """
     from typing import Any
 
-    import numpy as np
     from sox.core import sox
 
     from sonusai import SonusAIError
@@ -208,8 +210,11 @@ class Transformer(SoxTransformer):
 
         return self
 
-    def build_array(self,
+    def build_array(self,
+                    input_filepath: Optional[str | Path] = None,
+                    input_array: Optional[np.ndarray] = None,
+                    sample_rate_in: Optional[int] = None,
+                    extra_args: Optional[list[str]] = None) -> np.ndarray:
         """Given an input file or array, returns the output as a numpy array
         by executing the current set of commands. By default, the array will
         have the same sample rate as the input file unless otherwise specified
@@ -220,7 +225,7 @@ class Transformer(SoxTransformer):
 
         Parameters
         ----------
-        input_filepath : str or None
+        input_filepath : str, Path or None
             Either path to input audio file or None.
         input_array : np.ndarray or None
             A np.ndarray of a waveform with shape (n_samples, n_channels).
@@ -270,8 +275,6 @@ class Transformer(SoxTransformer):
 
 
         """
-        import numpy as np
-
         from sox.core import SoxError
         from sox.core import sox
         from sox.log import logger
@@ -324,13 +327,13 @@ class Transformer(SoxTransformer):
 
         match n_bits:
             case 8:
-                encoding_out = np.int8
+                encoding_out = np.int8  # type: ignore
             case 16:
                 encoding_out = np.int16
             case 32:
-                encoding_out = np.float32
+                encoding_out = np.float32  # type: ignore
             case 64:
-                encoding_out = np.float64
+                encoding_out = np.float64  # type: ignore
             case _:
                 raise ValueError("invalid n_bits {}".format(n_bits))
 
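A short sketch of calling the extended signature on the SoX-backed Transformer defined in this module; the resample rate and input file are illustrative:

    from pathlib import Path

    from sonusai.mixture.sox_audio import Transformer

    tfm = Transformer()
    tfm.rate(16000)  # any standard python-sox transform chain applies

    # input_filepath may now be a str or a pathlib.Path
    samples = tfm.build_array(input_filepath=Path('noise') / 'example.wav')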
sonusai/mixture/torchaudio_audio.py
CHANGED
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def read_impulse_response(name: str) -> ImpulseResponseData:
+def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using torchaudio
 
     :param name: File name
@@ -36,10 +38,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
     data = np.array(raw).astype(np.float32)
     data = data / np.linalg.norm(data)
 
-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
 
 
-def get_sample_rate(name: str) -> int:
+def get_sample_rate(name: str | Path) -> int:
     """Get sample rate from audio file using torchaudio
 
     :param name: File name
@@ -61,7 +63,7 @@ def get_sample_rate(name: str) -> int:
         raise SonusAIError(f'Error reading {name}:\n{e}')
 
 
-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using torchaudio
 
     :param name: File name