sonusai 0.18.1__py3-none-any.whl → 0.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +32 -362
- sonusai/data/genmixdb.yml +2 -0
- sonusai/doc/doc.py +45 -4
- sonusai/genmetrics.py +137 -109
- sonusai/lsdb.py +2 -2
- sonusai/metrics/__init__.py +4 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_pesq.py +12 -8
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_snr_f.py +34 -0
- sonusai/metrics/calc_speech.py +312 -0
- sonusai/metrics/calc_wer.py +2 -3
- sonusai/metrics/calc_wsdr.py +0 -59
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +6 -5
- sonusai/mixture/config.py +13 -0
- sonusai/mixture/constants.py +1 -0
- sonusai/mixture/datatypes.py +33 -0
- sonusai/mixture/generation.py +6 -2
- sonusai/mixture/mixdb.py +284 -148
- sonusai/mixture/soundfile_audio.py +8 -6
- sonusai/mixture/sox_audio.py +16 -13
- sonusai/mixture/torchaudio_audio.py +6 -4
- sonusai/mixture/truth_functions/energy.py +40 -28
- sonusai/mixture/truth_functions/target.py +0 -1
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr.py +26 -39
- sonusai/utils/asr_functions/aaware_whisper.py +3 -3
- {sonusai-0.18.1.dist-info → sonusai-0.18.4.dist-info}/METADATA +1 -1
- {sonusai-0.18.1.dist-info → sonusai-0.18.4.dist-info}/RECORD +34 -31
- sonusai/mixture/mapped_snr_f.py +0 -100
- {sonusai-0.18.1.dist-info → sonusai-0.18.4.dist-info}/WHEEL +0 -0
- {sonusai-0.18.1.dist-info → sonusai-0.18.4.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -4,9 +4,9 @@ from functools import partial
 from sqlite3 import Connection
 from sqlite3 import Cursor
 from typing import Any
-from typing import Callable
 from typing import Optional
 
+from sonusai.mixture.datatypes import ASRConfigs
 from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosF
@@ -89,6 +89,7 @@ class MixtureDatabase:
         from .datatypes import MixtureDatabaseConfig
 
         config = MixtureDatabaseConfig(
+            asr_configs=self.asr_configs,
             class_balancing=self.class_balancing,
             class_labels=self.class_labels,
             class_weights_threshold=self.class_weights_thresholds,
@@ -146,6 +147,30 @@ class MixtureDatabase:
         with self.db() as c:
             return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
 
+    @cached_property
+    def asr_configs(self) -> ASRConfigs:
+        import json
+
+        with self.db() as c:
+            return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
+
+    @cached_property
+    def supported_metrics(self) -> set[str]:
+        metrics = {
+            'mxssnravg', 'mxssnrvar', 'mxssnrdavg', 'mxssnrdstd',
+            'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', 'mxwsdr',
+            'mxpd',
+            'mxstoi',
+            'tdco', 'tmin', 'tmax', 'tpkdb', 'tlrms', 'tpkr', 'ttr', 'tcr', 'tfl', 'tpkc',
+            'ndco', 'nmin', 'nmax', 'npkdb', 'nlrms', 'npkr', 'ntr', 'ncr', 'nfl', 'npkc',
+            'sedavg', 'sedcnt', 'sedtopn',
+            'ssnr',
+        }
+        for name in self.asr_configs:
+            metrics.add(f'mxwer.{name}')
+
+        return metrics
+
     @cached_property
     def class_balancing(self) -> bool:
         with self.db() as c:
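The new asr_configs property deserializes the JSON stored in the database's top table, and supported_metrics derives the set of metric names the database can produce, adding one mxwer.<name> entry per ASR config. A minimal usage sketch, assuming a MixtureDatabase constructed as in the docstring examples below; the location and the printed values are illustrative, not part of this diff:

from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/mixdb_location')  # illustrative location
print(sorted(mixdb.supported_metrics))      # fixed metric names plus one 'mxwer.<name>' per ASR config
print(list(mixdb.asr_configs))              # ASR config names read from the 'top' table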
@@ -1062,13 +1087,13 @@
 
     def mixids_for_speech_metadata(self,
                                    tier: str,
-                                   value: str
+                                   value: str = None,
+                                   where: str = None) -> list[int]:
         """Get a list of mixture IDs for the given speech metadata tier.
 
-        If '
-        If '
-        to include.
+        If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
+        If 'where' is not None, then ignore 'value' and use the given SQL where clause to determine
+        which entries to include.
 
         Examples:
         >>> mixdb = MixtureDatabase('/mixdb_location')
@@ -1076,208 +1101,319 @@
         >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
         Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
 
-        >>> mixids = mixdb.mixids_for_speech_metadata('age', '
+        >>> mixids = mixdb.mixids_for_speech_metadata('age', where='age < 25')
         Get mixture IDs for mixtures with speakers whose ages are less than 25.
 
-        >>> mixids = mixdb.mixids_for_speech_metadata('dialect',
+        >>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
         Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
         """
-        from
+        from sonusai import SonusAIError
 
-        if
-            return x == value
+        if value is None and where is None:
+            raise SonusAIError('Must provide either value or where')
 
-        is_textgrid = tier in self.textgrid_metadata_tiers
-        for target_file_id, target_file in enumerate(self.target_files):
-            if is_textgrid:
-                metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
-            else:
-                metadata = self.speaker(target_file.speaker_id, tier)
+        if where is None:
+            where = f"{tier} = '{value}'"
 
+        if tier in self.textgrid_metadata_tiers:
+            raise SonusAIError(f'TextGrid tier data, "{tier}", is not supported in mixids_for_speech_metadata().')
 
-        # Next get list of mixture IDs that contain those target files
         with self.db() as c:
+            speaker_ids = [speaker_id[0] for speaker_id in
+                           c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()]
+            results = c.execute(f"SELECT id FROM target_file " +
+                                f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})").fetchall()
+            target_file_ids = [target_file_id[0] for target_file_id in results]
+            results = c.execute("SELECT mixture_id FROM mixture_target " +
+                                f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
+
+            return [mixture_id[0] - 1 for mixture_id in results]
 
     def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
         from .helpers import mixture_all_speech_metadata
 
         return mixture_all_speech_metadata(self, self.mixture(m_id))
 
-    def
+    def mixture_metrics(self, m_id: int,
+                        metrics: list[str],
+                        force: bool = False) -> list[float | int | Segsnr]:
+        """Get metrics data for the given mixture ID
 
         :param m_id: Zero-based mixture ID
-        :param
+        :param metrics: List of metrics to get
         :param force: Force computing data from original sources regardless of whether cached data exists
-        :return:
+        :return: List of metric data
         """
+        from typing import Callable
+
+        import numpy as np
+        from pystoi import stoi
+
         from sonusai import SonusAIError
+        from sonusai.metrics import calc_audio_stats
+        from sonusai.metrics import calc_phase_distance
+        from sonusai.metrics import calc_snr_f
+        from sonusai.metrics import calc_speech
+        from sonusai.metrics import calc_wer
+        from sonusai.metrics import calc_wsdr
+        from sonusai.mixture import SAMPLE_RATE
+        from sonusai.mixture import AudioStatsMetrics
+        from sonusai.mixture import SpeechMetrics
+        from sonusai.utils import calc_asr
 
-            'MXSSNRAVG',
-            'MXSSNRSTD',
-            'MXSSNRDAVG',
-            'MXSSNRDSTD',
-            'MXPESQ',
-            'MXWSDR',
-            'MXPD',
-            'MXSTOI',
-            'MXCSIG',
-            'MXCBAK',
-            'MXCOVL',
-            'TDCO',
-            'TMIN',
-            'TMAX',
-            'TPKDB',
-            'TLRMS',
-            'TPKR',
-            'TTR',
-            'TCR',
-            'TFL',
-            'TPKC',
-            'NDCO',
-            'NMIN',
-            'NMAX',
-            'NPKDB',
-            'NLRMS',
-            'NPKR',
-            'NTR',
-            'NCR',
-            'NFL',
-            'NPKC',
-            'SEDAVG',
-            'SEDCNT',
-            'SEDTOPN',
-        )
+        def create_target_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_target(m_id)
+                return state
 
-            result = self.read_mixture_data(m_id, metric)
-            if result is not None:
-                return result
+            return get
 
-        if mixture is None:
-            raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
+        target_audio = create_target_audio()
 
+        def create_noise_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_noise(m_id)
+                return state
 
-            return None
+            return get
 
-            return None
+        noise_audio = create_noise_audio()
 
+        def create_mixture_audio() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_mixture(m_id)
+                return state
 
-            return None
+            return get
 
-            return None
+        mixture_audio = create_mixture_audio()
 
+        def create_segsnr_f() -> Callable:
+            state = None
 
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_segsnr(m_id)
+                return state
 
-            return None
+            return get
 
-            return None
+        segsnr_f = create_segsnr_f()
 
+        def create_speech() -> Callable:
+            state = None
 
+            def get() -> SpeechMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_speech(hypothesis=mixture_audio(), reference=target_audio())
+                return state
 
-            return None
+            return get
 
-            return None
+        speech = create_speech()
 
+        def create_target_stats() -> Callable:
+            state = None
 
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state
 
-            return None
+            return get
 
-            return None
+        target_stats = create_target_stats()
 
+        def create_noise_stats() -> Callable:
+            state = None
 
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state
 
-            return None
+            return get
 
-            return None
+        noise_stats = create_noise_stats()
 
+        def calc(m: str) -> float | int | Segsnr:
+            if m == 'mxsnr':
+                return self.mixture(m_id).snr
 
+            # Get cached data first, if exists
+            if not force:
+                value = self.read_mixture_data(m_id, m)
+                if value is not None:
+                    return value
 
+            # Otherwise, generate data as needed
+            if m.startswith('mxwer'):
+                parts = m.split('.')
+                if len(parts) != 3:
+                    raise SonusAIError(
+                        f"Unrecognized 'mwwer' format: '{m}'; must be of the form: 'mxwer.<engine>.<model>'")
+                asr_engine = parts[1]
+                asr_model = parts[2]
 
+                if asr_engine == 'none' or self.mixture(m_id).snr < -96:
+                    # noise only, ignore/reset target asr
+                    return float('nan')
 
+                # ignore mixup
+                target_asr = self.mixture_speech_metadata(m_id, 'text')[0]
+                if target_asr is None:
+                    target_asr = calc_asr(target_audio(), engine=asr_engine, whisper_model_name=asr_model).text
 
+                if target_asr:
+                    mixture_asr = calc_asr(mixture_audio(), engine=asr_engine, whisper_model_name=asr_model).text
+                    return calc_wer(mixture_asr, target_asr).wer * 100
 
+                # TODO: should this be NaN like above?
+                return float(0)
 
+            if m == 'mxssnravg':
+                return calc_snr_f(segsnr_f()).mean
 
+            if m == 'mxssnrvar':
+                return calc_snr_f(segsnr_f()).var
 
+            if m == 'mxssnrdavg':
+                return calc_snr_f(segsnr_f()).db_mean
 
+            if m == 'mxssnrdstd':
+                return calc_snr_f(segsnr_f()).db_std
 
+            if m == 'mxpesq':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().pesq
+
+            if m == 'mxcsig':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_sig
+
+            if m == 'mxcbak':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_bak
+
+            if m == 'mxcovl':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_ovl
+
+            if m == 'mxwsdr':
+                mixture = mixture_audio()[:, np.newaxis]
+                target = target_audio()[:, np.newaxis]
+                noise = noise_audio()[:, np.newaxis]
+                return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
+                                 reference=np.concatenate((target, noise), axis=1),
+                                 with_log=True)[0]
+
+            if m == 'mxpd':
+                mixture_f = self.mixture_mixture_f(m_id)
+                target_f = self.mixture_target_f(m_id)
+                return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]
+
+            if m == 'mxstoi':
+                return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
+
+            if m == 'tdco':
+                return target_stats().dco
+
+            if m == 'tmin':
+                return target_stats().min
+
+            if m == 'tmax':
+                return target_stats().max
+
+            if m == 'tpkdb':
+                return target_stats().pkdb
+
+            if m == 'tlrms':
+                return target_stats().lrms
+
+            if m == 'tpkr':
+                return target_stats().pkr
+
+            if m == 'ttr':
+                return target_stats().tr
+
+            if m == 'tcr':
+                return target_stats().cr
+
+            if m == 'tfl':
+                return target_stats().fl
+
+            if m == 'tpkc':
+                return target_stats().pkc
+
+            if m == 'ndco':
+                return noise_stats().dco
+
+            if m == 'nmin':
+                return noise_stats().min
+
+            if m == 'nmax':
+                return noise_stats().max
+
+            if m == 'npkdb':
+                return noise_stats().pkdb
+
+            if m == 'nlrms':
+                return noise_stats().lrms
+
+            if m == 'npkr':
+                return noise_stats().pkr
+
+            if m == 'ntr':
+                return noise_stats().tr
+
+            if m == 'ncr':
+                return noise_stats().cr
+
+            if m == 'nfl':
+                return noise_stats().fl
+
+            if m == 'npkc':
+                return noise_stats().pkc
+
+            if m == 'sedavg':
+                return 0
+
+            if m == 'sedcnt':
+                return 0
+
+            if m == 'sedtopn':
+                return 0
+
+            if m == 'ssnr':
+                return self.mixture_segsnr(m_id)
+
+            raise SonusAIError(f"Unrecognized metric: '{m}'")
+
+        result: list[float | int | Segsnr] = []
+        for metric in metrics:
+            result.append(calc(metric))
+
+        return result
 
 
     @lru_cache
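The create_*() closures above memoize the expensive intermediates (target, noise, and mixture audio, segmental SNR, speech metrics, audio stats), so computing several metrics for one mixture reads or computes each source at most once, and per-mixture cached data is reused unless force=True. A minimal usage sketch; the mixture ID, database location, and chosen metric names are illustrative:

from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/mixdb_location')  # illustrative location
# Results come back in the same order as the requested metric names.
pesq, stoi_score, ssnr_avg = mixdb.mixture_metrics(m_id=0, metrics=['mxpesq', 'mxstoi', 'mxssnravg'])
recomputed = mixdb.mixture_metrics(m_id=0, metrics=['mxpesq'], force=True)  # bypass cached data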
sonusai/mixture/soundfile_audio.py
CHANGED
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def _raw_read(name: str) -> tuple[AudioT, int]:
+def _raw_read(name: str | Path) -> tuple[AudioT, int]:
     import numpy as np
     import soundfile
     from pydub import AudioSegment
@@ -34,7 +36,7 @@ def _raw_read(name: str) -> tuple[AudioT, int]:
     return np.squeeze(raw[:, 0]), sample_rate
 
 
-def get_sample_rate(name: str) -> int:
+def get_sample_rate(name: str | Path) -> int:
     """Get sample rate from audio file using soundfile
 
     :param name: File name
@@ -63,7 +65,7 @@ def get_sample_rate(name: str) -> int:
         raise SonusAIError(f'Error reading {name}: {e}')
 
 
-def read_ir(name: str) -> ImpulseResponseData:
+def read_ir(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using soundfile
 
     :param name: File name
@@ -79,10 +81,10 @@ def read_ir(name: str) -> ImpulseResponseData:
     out = out[offset:]
     out = out / np.linalg.norm(out)
 
-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=out)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=out)
 
 
-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using soundfile
 
     :param name: File name
@@ -101,7 +103,7 @@ def read_audio(name: str) -> AudioT:
     return out
 
 
-def get_num_samples(name: str) -> int:
+def get_num_samples(name: str | Path) -> int:
     """Get the number of samples resampled to the SonusAI sample rate in the given file
 
     :param name: File name
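These soundfile-backed readers now accept pathlib.Path as well as str, and ImpulseResponseData.name is coerced with str(). A small usage sketch; the file name is illustrative:

from pathlib import Path

from sonusai.mixture.soundfile_audio import get_sample_rate, read_audio

wav = Path('speech.wav')     # illustrative file
audio = read_audio(wav)      # previously required a str path
rate = get_sample_rate(wav)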
sonusai/mixture/sox_audio.py
CHANGED
@@ -1,16 +1,19 @@
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
 from sox import Transformer as SoxTransformer
 
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def read_impulse_response(name: str) -> ImpulseResponseData:
+def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using SoX
 
     :param name: File name
     :return: ImpulseResponseData object
     """
-    import numpy as np
     from scipy.io import wavfile
 
     from sonusai import SonusAIError
@@ -33,10 +36,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
     data = data[offset:]
     data = data / np.linalg.norm(data)
 
-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
 
 
-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using SoX
 
     :param name: File name
@@ -44,7 +47,6 @@ def read_audio(name: str) -> AudioT:
     """
     from typing import Any
 
-    import numpy as np
     from sox.core import sox
 
     from sonusai import SonusAIError
@@ -208,8 +210,11 @@ class Transformer(SoxTransformer):
 
         return self
 
-    def build_array(self,
+    def build_array(self,
+                    input_filepath: Optional[str | Path] = None,
+                    input_array: Optional[np.ndarray] = None,
+                    sample_rate_in: Optional[int] = None,
+                    extra_args: Optional[list[str]] = None) -> np.ndarray:
         """Given an input file or array, returns the output as a numpy array
         by executing the current set of commands. By default, the array will
         have the same sample rate as the input file unless otherwise specified
@@ -220,7 +225,7 @@ class Transformer(SoxTransformer):
 
         Parameters
         ----------
-        input_filepath : str or None
+        input_filepath : str, Path or None
            Either path to input audio file or None.
         input_array : np.ndarray or None
            A np.ndarray of a waveform with shape (n_samples, n_channels).
@@ -270,8 +275,6 @@ class Transformer(SoxTransformer):
 
 
         """
-        import numpy as np
-
         from sox.core import SoxError
         from sox.core import sox
         from sox.log import logger
@@ -324,13 +327,13 @@ class Transformer(SoxTransformer):
 
         match n_bits:
             case 8:
-                encoding_out = np.int8
+                encoding_out = np.int8  # type: ignore
             case 16:
                 encoding_out = np.int16
             case 32:
-                encoding_out = np.float32
+                encoding_out = np.float32  # type: ignore
             case 64:
-                encoding_out = np.float64
+                encoding_out = np.float64  # type: ignore
             case _:
                 raise ValueError("invalid n_bits {}".format(n_bits))
 
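Transformer.build_array() likewise accepts a Path for input_filepath. A hedged sketch; the resample effect comes from the pysox base class and the file name is illustrative:

from pathlib import Path

from sonusai.mixture.sox_audio import Transformer

tfm = Transformer()
tfm.rate(16000)  # pysox effect: resample to 16 kHz
samples = tfm.build_array(input_filepath=Path('speech.wav'))  # processed audio as np.ndarray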