sonusai 0.18.2__py3-none-any.whl → 0.18.4__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
sonusai/mixture/mixdb.py CHANGED
@@ -6,6 +6,7 @@ from sqlite3 import Cursor
 from typing import Any
 from typing import Optional

+from sonusai.mixture.datatypes import ASRConfigs
 from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosF
@@ -88,6 +89,7 @@ class MixtureDatabase:
         from .datatypes import MixtureDatabaseConfig

         config = MixtureDatabaseConfig(
+            asr_configs=self.asr_configs,
             class_balancing=self.class_balancing,
             class_labels=self.class_labels,
             class_weights_threshold=self.class_weights_thresholds,
@@ -145,6 +147,30 @@ class MixtureDatabase:
         with self.db() as c:
             return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])

+    @cached_property
+    def asr_configs(self) -> ASRConfigs:
+        import json
+
+        with self.db() as c:
+            return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
+
+    @cached_property
+    def supported_metrics(self) -> set[str]:
+        metrics = {
+            'mxssnravg', 'mxssnrvar', 'mxssnrdavg', 'mxssnrdstd',
+            'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', 'mxwsdr',
+            'mxpd',
+            'mxstoi',
+            'tdco', 'tmin', 'tmax', 'tpkdb', 'tlrms', 'tpkr', 'ttr', 'tcr', 'tfl', 'tpkc',
+            'ndco', 'nmin', 'nmax', 'npkdb', 'nlrms', 'npkr', 'ntr', 'ncr', 'nfl', 'npkc',
+            'sedavg', 'sedcnt', 'sedtopn',
+            'ssnr',
+        }
+        for name in self.asr_configs:
+            metrics.add(f'mxwer.{name}')
+
+        return metrics
+
     @cached_property
     def class_balancing(self) -> bool:
         with self.db() as c:
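
The new asr_configs property reads the ASR configuration dictionary stored in the database, and supported_metrics extends the fixed set of metric names with one mxwer.<name> entry per configured ASR. A minimal sketch of how a caller might validate requested metric names against this set, assuming MixtureDatabase is importable from sonusai.mixture and can be constructed from a mixture database location; the location and requested names below are placeholders:

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase('path/to/mixdb')  # placeholder location
    requested = ['mxpesq', 'mxstoi', 'tdco']
    unknown = [name for name in requested if name not in mixdb.supported_metrics]
    if unknown:
        raise ValueError(f'Unsupported metrics: {unknown}')
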
@@ -1108,173 +1134,286 @@ class MixtureDatabase:

         return mixture_all_speech_metadata(self, self.mixture(m_id))

-    def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
-        """Get metric data for the given mixture ID
+    def mixture_metrics(self, m_id: int,
+                        metrics: list[str],
+                        force: bool = False) -> list[float | int | Segsnr]:
+        """Get metrics data for the given mixture ID

         :param m_id: Zero-based mixture ID
-        :param metric: Metric data to retrieve
+        :param metrics: List of metrics to get
         :param force: Force computing data from original sources regardless of whether cached data exists
-        :return: Metric data
+        :return: List of metric data
         """
+        from typing import Callable
+
+        import numpy as np
+        from pystoi import stoi
+
         from sonusai import SonusAIError
+        from sonusai.metrics import calc_audio_stats
+        from sonusai.metrics import calc_phase_distance
+        from sonusai.metrics import calc_snr_f
+        from sonusai.metrics import calc_speech
+        from sonusai.metrics import calc_wer
+        from sonusai.metrics import calc_wsdr
+        from sonusai.mixture import SAMPLE_RATE
+        from sonusai.mixture import AudioStatsMetrics
+        from sonusai.mixture import SpeechMetrics
+        from sonusai.utils import calc_asr

-        supported_metrics = (
-            'MXSNR',
-            'MXSSNRAVG',
-            'MXSSNRSTD',
-            'MXSSNRDAVG',
-            'MXSSNRDSTD',
-            'MXPESQ',
-            'MXWSDR',
-            'MXPD',
-            'MXSTOI',
-            'MXCSIG',
-            'MXCBAK',
-            'MXCOVL',
-            'TDCO',
-            'TMIN',
-            'TMAX',
-            'TPKDB',
-            'TLRMS',
-            'TPKR',
-            'TTR',
-            'TCR',
-            'TFL',
-            'TPKC',
-            'NDCO',
-            'NMIN',
-            'NMAX',
-            'NPKDB',
-            'NLRMS',
-            'NPKR',
-            'NTR',
-            'NCR',
-            'NFL',
-            'NPKC',
-            'SEDAVG',
-            'SEDCNT',
-            'SEDTOPN',
-        )
+        def create_target_audio() -> Callable:
+            state = None

-        if not (metric in supported_metrics or metric.startswith('MXWER')):
-            raise SonusAIError(f'Unsupported metric: {metric}')
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_target(m_id)
+                return state

-        if not force:
-            result = self.read_mixture_data(m_id, metric)
-            if result is not None:
-                return result
+            return get

-        mixture = self.mixture(m_id)
-        if mixture is None:
-            raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
+        target_audio = create_target_audio()

-        if metric.startswith('MXWER'):
-            return None
+        def create_noise_audio() -> Callable:
+            state = None

-        if metric == 'MXSNR':
-            return self.snrs
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_noise(m_id)
+                return state

-        if metric == 'MXSSNRAVG':
-            return None
+            return get

-        if metric == 'MXSSNRSTD':
-            return None
+        noise_audio = create_noise_audio()

-        if metric == 'MXSSNRDAVG':
-            return None
+        def create_mixture_audio() -> Callable:
+            state = None

-        if metric == 'MXSSNRDSTD':
-            return None
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_mixture(m_id)
+                return state

-        if metric == 'MXPESQ':
-            return None
+            return get

-        if metric == 'MXWSDR':
-            return None
+        mixture_audio = create_mixture_audio()

-        if metric == 'MXPD':
-            return None
+        def create_segsnr_f() -> Callable:
+            state = None

-        if metric == 'MXSTOI':
-            return None
+            def get() -> np.ndarray:
+                nonlocal state
+                if state is None:
+                    state = self.mixture_segsnr(m_id)
+                return state

-        if metric == 'MXCSIG':
-            return None
+            return get

-        if metric == 'MXCBAK':
-            return None
+        segsnr_f = create_segsnr_f()

-        if metric == 'MXCOVL':
-            return None
+        def create_speech() -> Callable:
+            state = None

-        if metric == 'TDCO':
-            return None
+            def get() -> SpeechMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_speech(hypothesis=mixture_audio(), reference=target_audio())
+                return state

-        if metric == 'TMIN':
-            return None
+            return get

-        if metric == 'TMAX':
-            return None
+        speech = create_speech()

-        if metric == 'TPKDB':
-            return None
+        def create_target_stats() -> Callable:
+            state = None

-        if metric == 'TLRMS':
-            return None
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state

-        if metric == 'TPKR':
-            return None
+            return get

-        if metric == 'TTR':
-            return None
+        target_stats = create_target_stats()

-        if metric == 'TCR':
-            return None
+        def create_noise_stats() -> Callable:
+            state = None

-        if metric == 'TFL':
-            return None
+            def get() -> AudioStatsMetrics:
+                nonlocal state
+                if state is None:
+                    state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
+                return state

-        if metric == 'TPKC':
-            return None
+            return get

-        if metric == 'NDCO':
-            return None
+        noise_stats = create_noise_stats()

-        if metric == 'NMIN':
-            return None
+        def calc(m: str) -> float | int | Segsnr:
+            if m == 'mxsnr':
+                return self.mixture(m_id).snr

-        if metric == 'NMAX':
-            return None
+            # Get cached data first, if exists
+            if not force:
+                value = self.read_mixture_data(m_id, m)
+                if value is not None:
+                    return value

-        if metric == 'NPKDB':
-            return None
+            # Otherwise, generate data as needed
+            if m.startswith('mxwer'):
+                parts = m.split('.')
+                if len(parts) != 3:
+                    raise SonusAIError(
+                        f"Unrecognized 'mwwer' format: '{m}'; must be of the form: 'mxwer.<engine>.<model>'")
+                asr_engine = parts[1]
+                asr_model = parts[2]

-        if metric == 'NLRMS':
-            return None
+                if asr_engine == 'none' or self.mixture(m_id).snr < -96:
+                    # noise only, ignore/reset target asr
+                    return float('nan')

-        if metric == 'NPKR':
-            return None
+                # ignore mixup
+                target_asr = self.mixture_speech_metadata(m_id, 'text')[0]
+                if target_asr is None:
+                    target_asr = calc_asr(target_audio(), engine=asr_engine, whisper_model_name=asr_model).text

-        if metric == 'NTR':
-            return None
+                if target_asr:
+                    mixture_asr = calc_asr(mixture_audio(), engine=asr_engine, whisper_model_name=asr_model).text
+                    return calc_wer(mixture_asr, target_asr).wer * 100

-        if metric == 'NCR':
-            return None
+                # TODO: should this be NaN like above?
+                return float(0)

-        if metric == 'NFL':
-            return None
+            if m == 'mxssnravg':
+                return calc_snr_f(segsnr_f()).mean

-        if metric == 'NPKC':
-            return None
+            if m == 'mxssnrvar':
+                return calc_snr_f(segsnr_f()).var

-        if metric == 'SEDAVG':
-            return None
+            if m == 'mxssnrdavg':
+                return calc_snr_f(segsnr_f()).db_mean

-        if metric == 'SEDCNT':
-            return None
+            if m == 'mxssnrdstd':
+                return calc_snr_f(segsnr_f()).db_std

-        if metric == 'SEDTOPN':
-            return None
+            if m == 'mxpesq':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().pesq
+
+            if m == 'mxcsig':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_sig
+
+            if m == 'mxcbak':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_bak
+
+            if m == 'mxcovl':
+                if self.mixture(m_id).snr < -96:
+                    return 0
+                return speech().c_ovl
+
+            if m == 'mxwsdr':
+                mixture = mixture_audio()[:, np.newaxis]
+                target = target_audio()[:, np.newaxis]
+                noise = noise_audio()[:, np.newaxis]
+                return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
+                                 reference=np.concatenate((target, noise), axis=1),
+                                 with_log=True)[0]
+
+            if m == 'mxpd':
+                mixture_f = self.mixture_mixture_f(m_id)
+                target_f = self.mixture_target_f(m_id)
+                return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]
+
+            if m == 'mxstoi':
+                return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
+
+            if m == 'tdco':
+                return target_stats().dco
+
+            if m == 'tmin':
+                return target_stats().min
+
+            if m == 'tmax':
+                return target_stats().max
+
+            if m == 'tpkdb':
+                return target_stats().pkdb
+
+            if m == 'tlrms':
+                return target_stats().lrms
+
+            if m == 'tpkr':
+                return target_stats().pkr
+
+            if m == 'ttr':
+                return target_stats().tr
+
+            if m == 'tcr':
+                return target_stats().cr
+
+            if m == 'tfl':
+                return target_stats().fl
+
+            if m == 'tpkc':
+                return target_stats().pkc
+
+            if m == 'ndco':
+                return noise_stats().dco
+
+            if m == 'nmin':
+                return noise_stats().min
+
+            if m == 'nmax':
+                return noise_stats().max
+
+            if m == 'npkdb':
+                return noise_stats().pkdb
+
+            if m == 'nlrms':
+                return noise_stats().lrms
+
+            if m == 'npkr':
+                return noise_stats().pkr
+
+            if m == 'ntr':
+                return noise_stats().tr
+
+            if m == 'ncr':
+                return noise_stats().cr
+
+            if m == 'nfl':
+                return noise_stats().fl
+
+            if m == 'npkc':
+                return noise_stats().pkc
+
+            if m == 'sedavg':
+                return 0
+
+            if m == 'sedcnt':
+                return 0
+
+            if m == 'sedtopn':
+                return 0
+
+            if m == 'ssnr':
+                return self.mixture_segsnr(m_id)
+
+            raise SonusAIError(f"Unrecognized metric: '{m}'")
+
+        result: list[float | int | Segsnr] = []
+        for metric in metrics:
+            result.append(calc(metric))
+
+        return result


     @lru_cache
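
The single-metric mixture_metric method is replaced by mixture_metrics, which takes a list of metric names and computes shared intermediates (target, noise, and mixture audio, segmental SNR, speech metrics, and audio stats) lazily through small closures, so each is generated at most once per call. A minimal usage sketch, assuming MixtureDatabase can be constructed from a mixture database location; the location and mixture ID are placeholders:

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase('path/to/mixdb')  # placeholder location
    names = ['mxssnravg', 'mxpesq', 'tdco']
    values = mixdb.mixture_metrics(m_id=0, metrics=names)
    for name, value in zip(names, values):
        print(f'{name}: {value}')
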
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData


-def _raw_read(name: str) -> tuple[AudioT, int]:
+def _raw_read(name: str | Path) -> tuple[AudioT, int]:
     import numpy as np
     import soundfile
     from pydub import AudioSegment
@@ -34,7 +36,7 @@ def _raw_read(name: str) -> tuple[AudioT, int]:
     return np.squeeze(raw[:, 0]), sample_rate


-def get_sample_rate(name: str) -> int:
+def get_sample_rate(name: str | Path) -> int:
     """Get sample rate from audio file using soundfile

     :param name: File name
@@ -63,7 +65,7 @@ def get_sample_rate(name: str) -> int:
         raise SonusAIError(f'Error reading {name}: {e}')


-def read_ir(name: str) -> ImpulseResponseData:
+def read_ir(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using soundfile

     :param name: File name
@@ -79,10 +81,10 @@ def read_ir(name: str) -> ImpulseResponseData:
     out = out[offset:]
     out = out / np.linalg.norm(out)

-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=out)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=out)


-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using soundfile

     :param name: File name
@@ -101,7 +103,7 @@ def read_audio(name: str) -> AudioT:
     return out


-def get_num_samples(name: str) -> int:
+def get_num_samples(name: str | Path) -> int:
     """Get the number of samples resampled to the SonusAI sample rate in the given file

     :param name: File name
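
These soundfile-based helpers now accept either str or pathlib.Path file names. A minimal usage sketch, assuming the readers are re-exported from sonusai.mixture (if not, import them from the module that defines them); the file path is a placeholder:

    from pathlib import Path

    from sonusai.mixture import get_sample_rate
    from sonusai.mixture import read_audio

    wav = Path('speech') / 'example.wav'  # placeholder path
    audio = read_audio(wav)               # Path now accepted as well as str
    rate = get_sample_rate(wav)
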
@@ -1,16 +1,19 @@
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
 from sox import Transformer as SoxTransformer

 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData


-def read_impulse_response(name: str) -> ImpulseResponseData:
+def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using SoX

     :param name: File name
     :return: ImpulseResponseData object
     """
-    import numpy as np
     from scipy.io import wavfile

     from sonusai import SonusAIError
@@ -33,10 +36,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
     data = data[offset:]
     data = data / np.linalg.norm(data)

-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)


-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using SoX

     :param name: File name
@@ -44,7 +47,6 @@ def read_audio(name: str) -> AudioT:
     """
     from typing import Any

-    import numpy as np
     from sox.core import sox

     from sonusai import SonusAIError
@@ -208,8 +210,11 @@ class Transformer(SoxTransformer):

         return self

-    def build_array(self, input_filepath=None, input_array=None,
-                    sample_rate_in=None, extra_args=None):
+    def build_array(self,
+                    input_filepath: Optional[str | Path] = None,
+                    input_array: Optional[np.ndarray] = None,
+                    sample_rate_in: Optional[int] = None,
+                    extra_args: Optional[list[str]] = None) -> np.ndarray:
         """Given an input file or array, returns the output as a numpy array
         by executing the current set of commands. By default, the array will
         have the same sample rate as the input file unless otherwise specified
@@ -220,7 +225,7 @@ class Transformer(SoxTransformer):

         Parameters
         ----------
-        input_filepath : str or None
+        input_filepath : str, Path or None
             Either path to input audio file or None.
         input_array : np.ndarray or None
             A np.ndarray of a waveform with shape (n_samples, n_channels).
@@ -270,8 +275,6 @@ class Transformer(SoxTransformer):


         """
-        import numpy as np
-
         from sox.core import SoxError
         from sox.core import sox
         from sox.log import logger
@@ -324,13 +327,13 @@ class Transformer(SoxTransformer):

         match n_bits:
             case 8:
-                encoding_out = np.int8
+                encoding_out = np.int8  # type: ignore
             case 16:
                 encoding_out = np.int16
             case 32:
-                encoding_out = np.float32
+                encoding_out = np.float32  # type: ignore
             case 64:
-                encoding_out = np.float64
+                encoding_out = np.float64  # type: ignore
             case _:
                 raise ValueError("invalid n_bits {}".format(n_bits))

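
The build_array override now carries type annotations for the behavior it already had: it takes an input file path (str or Path) or an in-memory array plus its sample rate, and returns a NumPy array whose dtype follows the n_bits match above (8 -> int8, 16 -> int16, 32 -> float32, 64 -> float64). A hedged usage sketch; the module path sonusai.mixture.sox_audio and the values below are assumptions for illustration:

    import numpy as np

    from sonusai.mixture.sox_audio import Transformer  # module path is an assumption

    tfm = Transformer()
    tfm.rate(16000)                            # resample via SoX
    x = np.zeros((8000, 1), dtype=np.float32)  # (n_samples, n_channels)
    y = tfm.build_array(input_array=x, sample_rate_in=44100)
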
@@ -1,8 +1,10 @@
+from pathlib import Path
+
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData


-def read_impulse_response(name: str) -> ImpulseResponseData:
+def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     """Read impulse response data using torchaudio

     :param name: File name
@@ -36,10 +38,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
     data = np.array(raw).astype(np.float32)
     data = data / np.linalg.norm(data)

-    return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
+    return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)


-def get_sample_rate(name: str) -> int:
+def get_sample_rate(name: str | Path) -> int:
     """Get sample rate from audio file using torchaudio

     :param name: File name
@@ -61,7 +63,7 @@ def get_sample_rate(name: str) -> int:
         raise SonusAIError(f'Error reading {name}:\n{e}')


-def read_audio(name: str) -> AudioT:
+def read_audio(name: str | Path) -> AudioT:
     """Read audio data from a file using torchaudio

     :param name: File name
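
As in the soundfile and SoX readers above, the torchaudio helpers now accept str | Path and normalize the name with str(name) before constructing ImpulseResponseData, which stores the name as a plain string. A hedged usage sketch; the module path sonusai.mixture.torchaudio_audio and the file path are assumptions for illustration:

    from pathlib import Path

    from sonusai.mixture.torchaudio_audio import read_impulse_response  # module path is an assumption

    ir = read_impulse_response(Path('impulse_responses') / 'room.wav')  # placeholder path
    print(ir.name, ir.sample_rate, len(ir.data))
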