sonusai 0.18.6__py3-none-any.whl → 0.18.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/mixdb.py CHANGED
@@ -17,6 +17,8 @@ from sonusai.mixture.datatypes import FeatureGeneratorConfig
17
17
  from sonusai.mixture.datatypes import FeatureGeneratorInfo
18
18
  from sonusai.mixture.datatypes import GeneralizedIDs
19
19
  from sonusai.mixture.datatypes import ImpulseResponseFiles
20
+ from sonusai.mixture.datatypes import MetricDoc
21
+ from sonusai.mixture.datatypes import MetricDocs
20
22
  from sonusai.mixture.datatypes import Mixture
21
23
  from sonusai.mixture.datatypes import Mixtures
22
24
  from sonusai.mixture.datatypes import NoiseFile
@@ -155,19 +157,69 @@ class MixtureDatabase:
155
157
  return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
156
158
 
157
159
  @cached_property
158
- def supported_metrics(self) -> set[str]:
159
- metrics = {
160
- 'mxssnravg', 'mxssnrvar', 'mxssnrdavg', 'mxssnrdstd',
161
- 'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', 'mxwsdr',
162
- 'mxpd',
163
- 'mxstoi',
164
- 'tdco', 'tmin', 'tmax', 'tpkdb', 'tlrms', 'tpkr', 'ttr', 'tcr', 'tfl', 'tpkc',
165
- 'ndco', 'nmin', 'nmax', 'npkdb', 'nlrms', 'npkr', 'ntr', 'ncr', 'nfl', 'npkc',
166
- 'sedavg', 'sedcnt', 'sedtopn',
167
- 'ssnr',
168
- }
160
+ def supported_metrics(self) -> MetricDocs:
161
+ metrics = MetricDocs([
162
+ MetricDoc('Mixture Metrics', 'mxsnr', 'SNR specification in dB'),
163
+ MetricDoc('Mixture Metrics', 'mxssnr_avg', 'Segmental SNR average over all frames'),
164
+ MetricDoc('Mixture Metrics', 'mxssnr_std', 'Segmental SNR standard deviation over all frames'),
165
+ MetricDoc('Mixture Metrics', 'mxssnrdb_avg',
166
+ 'Segmental SNR average of the dB frame values over all frames'),
167
+ MetricDoc('Mixture Metrics', 'mxssnrdb_std',
168
+ 'Segmental SNR standard deviation of the dB frame values over all frames'),
169
+ MetricDoc('Mixture Metrics', 'mxssnrf_avg',
170
+ 'Per-bin segmental SNR average over all frames (using feature transform)'),
171
+ MetricDoc('Mixture Metrics', 'mxssnrf_std',
172
+ 'Per-bin segmental SNR standard deviation over all frames (using feature transform)'),
173
+ MetricDoc('Mixture Metrics', 'mxssnrdbf_avg',
174
+ 'Per-bin segmental average of the dB frame values over all frames (using feature transform)'),
175
+ MetricDoc('Mixture Metrics', 'mxssnrdbf_std',
176
+ 'Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)'),
177
+ MetricDoc('Mixture Metrics', 'mxpesq', 'PESQ of mixture versus true target[0]'),
178
+ MetricDoc('Mixture Metrics', 'mxwsdr', 'Weighted signal distorion ratio of mixture versus true target[0]'),
179
+ MetricDoc('Mixture Metrics', 'mxpd', 'Phase distance between mixture and true target[0]'),
180
+ MetricDoc('Mixture Metrics', 'mxstoi',
181
+ 'Short term objective intelligibility of mixture versus true target[0]'),
182
+ MetricDoc('Mixture Metrics', 'mxcsig',
183
+ 'Predicted rating of speech distortion of mixture versus true target[0]'),
184
+ MetricDoc('Mixture Metrics', 'mxcbak',
185
+ 'Predicted rating of background distortion of mixture versus true target[0]'),
186
+ MetricDoc('Mixture Metrics', 'mxcovl',
187
+ 'Predicted rating of overall quality of mixture versus true target[0]'),
188
+ MetricDoc('Mixture Metrics', 'ssnr', 'Segmental SNR'),
189
+ MetricDoc('Target Metrics', 'tdco', 'Target[0] DC offset'),
190
+ MetricDoc('Target Metrics', 'tmin', 'Target[0] min level'),
191
+ MetricDoc('Target Metrics', 'tmax', 'Target[0] max levl'),
192
+ MetricDoc('Target Metrics', 'tpkdb', 'Target[0] Pk lev dB'),
193
+ MetricDoc('Target Metrics', 'tlrms', 'Target[0] RMS lev dB'),
194
+ MetricDoc('Target Metrics', 'tpkr', 'Target[0] RMS Pk dB'),
195
+ MetricDoc('Target Metrics', 'ttr', 'Target[0] RMS Tr dB'),
196
+ MetricDoc('Target Metrics', 'tcr', 'Target[0] Crest factor'),
197
+ MetricDoc('Target Metrics', 'tfl', 'Target[0] Flat factor'),
198
+ MetricDoc('Target Metrics', 'tpkc', 'Target[0] Pk count'),
199
+ MetricDoc('Noise Metrics', 'ndco', 'Noise DC offset'),
200
+ MetricDoc('Noise Metrics', 'nmin', 'Noise min level'),
201
+ MetricDoc('Noise Metrics', 'nmax', 'Noise max levl'),
202
+ MetricDoc('Noise Metrics', 'npkdb', 'Noise Pk lev dB'),
203
+ MetricDoc('Noise Metrics', 'nlrms', 'Noise RMS lev dB'),
204
+ MetricDoc('Noise Metrics', 'npkr', 'Noise RMS Pk dB'),
205
+ MetricDoc('Noise Metrics', 'ntr', 'Noise RMS Tr dB'),
206
+ MetricDoc('Noise Metrics', 'ncr', 'Noise Crest factor'),
207
+ MetricDoc('Noise Metrics', 'nfl', 'Noise Flat factor'),
208
+ MetricDoc('Noise Metrics', 'npkc', 'Noise Pk count'),
209
+ MetricDoc('Truth Metrics', 'sedavg',
210
+ '(not implemented) Average SED activity over all frames [num_classes, 1]'),
211
+ MetricDoc('Truth Metrics', 'sedcnt',
212
+ '(not implemented) Count in number of frames that SED is active [num_classes, 1]'),
213
+ MetricDoc('Truth Metrics', 'sedtop3', '(not implemented) 3 most active by largest sedavg [3, 1]'),
214
+ MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
215
+ ])
169
216
  for name in self.asr_configs:
170
- metrics.add(f'mxwer.{name}')
217
+ metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
218
+ f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
219
+ metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
220
+ f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
221
+ metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
222
+ f'Word error rate using {name} ASR as defined in mixdb asr_configs parameter'))
171
223
 
172
224
  return metrics
173
225
 
@@ -240,11 +292,15 @@ class MixtureDatabase:
240
292
  def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
241
293
  return self.total_samples(m_ids) // self.feature_step_samples
242
294
 
243
- def mixture_transform_frames(self, samples: int) -> int:
244
- return samples // self.ft_config.R
295
+ def mixture_transform_frames(self, m_id: int) -> int:
296
+ from .helpers import frames_from_samples
245
297
 
246
- def mixture_feature_frames(self, samples: int) -> int:
247
- return samples // self.feature_step_samples
298
+ return frames_from_samples(self.mixture(m_id).samples, self.ft_config.R)
299
+
300
+ def mixture_feature_frames(self, m_id: int) -> int:
301
+ from .helpers import frames_from_samples
302
+
303
+ return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
248
304
 
249
305
  def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
250
306
  """Resolve generalized mixture IDs to a list of integers
@@ -907,8 +963,8 @@ class MixtureDatabase:
907
963
  truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
908
964
 
909
965
  m = self.mixture(m_id)
910
- transform_frames = self.mixture_transform_frames(m.samples)
911
- feature_frames = self.mixture_feature_frames(m.samples)
966
+ transform_frames = self.mixture_transform_frames(m_id)
967
+ feature_frames = self.mixture_feature_frames(m_id)
912
968
 
913
969
  if truth_t is None:
914
970
  truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
@@ -1133,7 +1189,7 @@ class MixtureDatabase:
1133
1189
 
1134
1190
  def mixture_metrics(self, m_id: int,
1135
1191
  metrics: list[str],
1136
- force: bool = False) -> list[float | int | Segsnr]:
1192
+ force: bool = False) -> list[float | int | str | Segsnr]:
1137
1193
  """Get metrics data for the given mixture ID
1138
1194
 
1139
1195
  :param m_id: Zero-based mixture ID
@@ -1149,7 +1205,8 @@ class MixtureDatabase:
1149
1205
  from sonusai import SonusAIError
1150
1206
  from sonusai.metrics import calc_audio_stats
1151
1207
  from sonusai.metrics import calc_phase_distance
1152
- from sonusai.metrics import calc_snr_f
1208
+ from sonusai.metrics import calc_segsnr_f
1209
+ from sonusai.metrics import calc_segsnr_f_bin
1153
1210
  from sonusai.metrics import calc_speech
1154
1211
  from sonusai.metrics import calc_wer
1155
1212
  from sonusai.metrics import calc_wsdr
@@ -1158,7 +1215,7 @@ class MixtureDatabase:
1158
1215
  from sonusai.mixture import SpeechMetrics
1159
1216
  from sonusai.utils import calc_asr
1160
1217
 
1161
- def create_target_audio() -> Callable:
1218
+ def create_target_audio() -> Callable[[], np.ndarray]:
1162
1219
  state = None
1163
1220
 
1164
1221
  def get() -> np.ndarray:
@@ -1171,7 +1228,20 @@ class MixtureDatabase:
1171
1228
 
1172
1229
  target_audio = create_target_audio()
1173
1230
 
1174
- def create_noise_audio() -> Callable:
1231
+ def create_target_f() -> Callable[[], np.ndarray]:
1232
+ state = None
1233
+
1234
+ def get() -> np.ndarray:
1235
+ nonlocal state
1236
+ if state is None:
1237
+ state = self.mixture_targets_f(m_id)[0]
1238
+ return state
1239
+
1240
+ return get
1241
+
1242
+ target_f = create_target_f()
1243
+
1244
+ def create_noise_audio() -> Callable[[], np.ndarray]:
1175
1245
  state = None
1176
1246
 
1177
1247
  def get() -> np.ndarray:
@@ -1184,7 +1254,20 @@ class MixtureDatabase:
1184
1254
 
1185
1255
  noise_audio = create_noise_audio()
1186
1256
 
1187
- def create_mixture_audio() -> Callable:
1257
+ def create_noise_f() -> Callable[[], np.ndarray]:
1258
+ state = None
1259
+
1260
+ def get() -> np.ndarray:
1261
+ nonlocal state
1262
+ if state is None:
1263
+ state = self.mixture_noise_f(m_id)
1264
+ return state
1265
+
1266
+ return get
1267
+
1268
+ noise_f = create_noise_f()
1269
+
1270
+ def create_mixture_audio() -> Callable[[], np.ndarray]:
1188
1271
  state = None
1189
1272
 
1190
1273
  def get() -> np.ndarray:
@@ -1197,7 +1280,7 @@ class MixtureDatabase:
1197
1280
 
1198
1281
  mixture_audio = create_mixture_audio()
1199
1282
 
1200
- def create_segsnr_f() -> Callable:
1283
+ def create_segsnr_f() -> Callable[[], np.ndarray]:
1201
1284
  state = None
1202
1285
 
1203
1286
  def get() -> np.ndarray:
@@ -1210,7 +1293,7 @@ class MixtureDatabase:
1210
1293
 
1211
1294
  segsnr_f = create_segsnr_f()
1212
1295
 
1213
- def create_speech() -> Callable:
1296
+ def create_speech() -> Callable[[], SpeechMetrics]:
1214
1297
  state = None
1215
1298
 
1216
1299
  def get() -> SpeechMetrics:
@@ -1223,7 +1306,7 @@ class MixtureDatabase:
1223
1306
 
1224
1307
  speech = create_speech()
1225
1308
 
1226
- def create_target_stats() -> Callable:
1309
+ def create_target_stats() -> Callable[[], AudioStatsMetrics]:
1227
1310
  state = None
1228
1311
 
1229
1312
  def get() -> AudioStatsMetrics:
@@ -1236,7 +1319,7 @@ class MixtureDatabase:
1236
1319
 
1237
1320
  target_stats = create_target_stats()
1238
1321
 
1239
- def create_noise_stats() -> Callable:
1322
+ def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1240
1323
  state = None
1241
1324
 
1242
1325
  def get() -> AudioStatsMetrics:
@@ -1249,7 +1332,56 @@ class MixtureDatabase:
1249
1332
 
1250
1333
  noise_stats = create_noise_stats()
1251
1334
 
1252
- def calc(m: str) -> float | int | Segsnr:
1335
+ def create_asr_config() -> Callable[[str], dict]:
1336
+ state: dict[str, dict] = {}
1337
+
1338
+ def get(asr_name) -> dict:
1339
+ nonlocal state
1340
+ if asr_name not in state:
1341
+ state[asr_name] = self.asr_configs.get(asr_name, None)
1342
+ if state[asr_name] is None:
1343
+ raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
1344
+ return state[asr_name]
1345
+
1346
+ return get
1347
+
1348
+ asr_config = create_asr_config()
1349
+
1350
+ def create_target_asr() -> Callable[[str], str]:
1351
+ state: dict[str, str] = {}
1352
+
1353
+ def get(asr_name) -> str:
1354
+ nonlocal state
1355
+ if asr_name not in state:
1356
+ state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
1357
+ return state[asr_name]
1358
+
1359
+ return get
1360
+
1361
+ target_asr = create_target_asr()
1362
+
1363
+ def create_mixture_asr() -> Callable[[str], str]:
1364
+ state: dict[str, str] = {}
1365
+
1366
+ def get(asr_name) -> str:
1367
+ nonlocal state
1368
+ if asr_name not in state:
1369
+ state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
1370
+ return state[asr_name]
1371
+
1372
+ return get
1373
+
1374
+ mixture_asr = create_mixture_asr()
1375
+
1376
+ def get_asr_name(m: str) -> str:
1377
+ parts = m.split('.')
1378
+ if len(parts) != 2:
1379
+ raise SonusAIError(
1380
+ f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1381
+ asr_name = parts[1]
1382
+ return asr_name
1383
+
1384
+ def calc(m: str) -> float | int | str | Segsnr:
1253
1385
  if m == 'mxsnr':
1254
1386
  return self.mixture(m_id).snr
1255
1387
 
@@ -1261,42 +1393,44 @@ class MixtureDatabase:
1261
1393
 
1262
1394
  # Otherwise, generate data as needed
1263
1395
  if m.startswith('mxwer'):
1264
- parts = m.split('.')
1265
- if len(parts) != 2:
1266
- raise SonusAIError(
1267
- f"Unrecognized 'mxwer' format: '{m}'; must be of the form: 'mxwer.<name>'")
1268
- asr_name = parts[1]
1269
- asr_config = self.asr_configs.get(asr_name, None)
1270
- if asr_config is None:
1271
- raise SonusAIError(f"Unrecognized metric: '{m}'")
1396
+ asr_name = get_asr_name(m)
1272
1397
 
1273
1398
  if self.mixture(m_id).snr < -96:
1274
1399
  # noise only, ignore/reset target asr
1275
1400
  return float('nan')
1276
1401
 
1277
- # ignore mixup
1278
- target_asr = self.mixture_speech_metadata(m_id, 'text')[0]
1279
- if target_asr is None:
1280
- target_asr = calc_asr(target_audio(), **asr_config).text
1281
-
1282
- if target_asr:
1283
- mixture_asr = calc_asr(mixture_audio(), **asr_config).text
1284
- return calc_wer(mixture_asr, target_asr).wer * 100
1402
+ if target_asr(asr_name):
1403
+ return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
1285
1404
 
1286
1405
  # TODO: should this be NaN like above?
1287
1406
  return float(0)
1288
1407
 
1289
- if m == 'mxssnravg':
1290
- return calc_snr_f(segsnr_f()).mean
1408
+ if m.startswith('mxasr'):
1409
+ return mixture_asr(get_asr_name(m))
1410
+
1411
+ if m == 'mxssnr_avg':
1412
+ return calc_segsnr_f(segsnr_f()).avg
1291
1413
 
1292
- if m == 'mxssnrvar':
1293
- return calc_snr_f(segsnr_f()).var
1414
+ if m == 'mxssnr_std':
1415
+ return calc_segsnr_f(segsnr_f()).std
1294
1416
 
1295
- if m == 'mxssnrdavg':
1296
- return calc_snr_f(segsnr_f()).db_mean
1417
+ if m == 'mxssnrdb_avg':
1418
+ return calc_segsnr_f(segsnr_f()).db_avg
1297
1419
 
1298
- if m == 'mxssnrdstd':
1299
- return calc_snr_f(segsnr_f()).db_std
1420
+ if m == 'mxssnrdb_std':
1421
+ return calc_segsnr_f(segsnr_f()).db_std
1422
+
1423
+ if m == 'mxssnrf_avg':
1424
+ return calc_segsnr_f_bin(target_f(), noise_f()).avg
1425
+
1426
+ if m == 'mxssnrf_std':
1427
+ return calc_segsnr_f_bin(target_f(), noise_f()).std
1428
+
1429
+ if m == 'mxssnrdbf_avg':
1430
+ return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
1431
+
1432
+ if m == 'mxssnrdbf_std':
1433
+ return calc_segsnr_f_bin(target_f(), noise_f()).db_std
1300
1434
 
1301
1435
  if m == 'mxpesq':
1302
1436
  if self.mixture(m_id).snr < -96:
@@ -1306,17 +1440,17 @@ class MixtureDatabase:
1306
1440
  if m == 'mxcsig':
1307
1441
  if self.mixture(m_id).snr < -96:
1308
1442
  return 0
1309
- return speech().c_sig
1443
+ return speech().csig
1310
1444
 
1311
1445
  if m == 'mxcbak':
1312
1446
  if self.mixture(m_id).snr < -96:
1313
1447
  return 0
1314
- return speech().c_bak
1448
+ return speech().cbak
1315
1449
 
1316
1450
  if m == 'mxcovl':
1317
1451
  if self.mixture(m_id).snr < -96:
1318
1452
  return 0
1319
- return speech().c_ovl
1453
+ return speech().covl
1320
1454
 
1321
1455
  if m == 'mxwsdr':
1322
1456
  mixture = mixture_audio()[:, np.newaxis]
@@ -1328,8 +1462,7 @@ class MixtureDatabase:
1328
1462
 
1329
1463
  if m == 'mxpd':
1330
1464
  mixture_f = self.mixture_mixture_f(m_id)
1331
- target_f = self.mixture_target_f(m_id)
1332
- return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]
1465
+ return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
1333
1466
 
1334
1467
  if m == 'mxstoi':
1335
1468
  return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
@@ -1364,6 +1497,9 @@ class MixtureDatabase:
1364
1497
  if m == 'tpkc':
1365
1498
  return target_stats().pkc
1366
1499
 
1500
+ if m.startswith('tasr'):
1501
+ return target_asr(get_asr_name(m))
1502
+
1367
1503
  if m == 'ndco':
1368
1504
  return noise_stats().dco
1369
1505
 
@@ -1400,15 +1536,18 @@ class MixtureDatabase:
1400
1536
  if m == 'sedcnt':
1401
1537
  return 0
1402
1538
 
1539
+ if m == 'sedtop3':
1540
+ return np.zeros(3, dtype=np.float32)
1541
+
1403
1542
  if m == 'sedtopn':
1404
1543
  return 0
1405
1544
 
1406
1545
  if m == 'ssnr':
1407
- return self.mixture_segsnr(m_id)
1546
+ return segsnr_f()
1408
1547
 
1409
1548
  raise SonusAIError(f"Unrecognized metric: '{m}'")
1410
1549
 
1411
- result: list[float | int | Segsnr] = []
1550
+ result: list[float | int | str | Segsnr] = []
1412
1551
  for metric in metrics:
1413
1552
  result.append(calc(metric))
1414
1553
 
@@ -210,6 +210,131 @@ class Transformer(SoxTransformer):
210
210
 
211
211
  return self
212
212
 
213
+ def build(self,
214
+ input_filepath: Optional[str | Path] = None,
215
+ output_filepath: Optional[str | Path] = None,
216
+ input_array: Optional[np.ndarray] = None,
217
+ sample_rate_in: Optional[float] = None,
218
+ extra_args: Optional[list[str]] = None,
219
+ return_output: bool = False) -> tuple[bool, Optional[str], Optional[str]]:
220
+ """Given an input file or array, creates an output_file on disk by
221
+ executing the current set of commands. This function always returns a
222
+ triple of (status, out, err). If return_output is True, the triple holds
223
+ the status reported by sox along with the stdout and stderr it produced;
224
+ otherwise the triple is (True, None, None).
225
+
226
+ Parameters
227
+ ----------
228
+ input_filepath : str or None
229
+ Either path to input audio file or None for array input.
230
+ output_filepath : str
231
+ Path to desired output file. If a file already exists at
232
+ the given path, the file will be overwritten.
233
+ If '-n', no file is created.
234
+ input_array : np.ndarray or None
235
+ An np.ndarray of a waveform with shape (n_samples, n_channels).
236
+ sample_rate_in must also be provided.
237
+ If None, input_filepath must be specified.
238
+ sample_rate_in : int
239
+ Sample rate of input_array.
240
+ This argument is ignored if input_array is None.
241
+ extra_args : list or None, default=None
242
+ If a list is given, these additional arguments are passed to SoX
243
+ at the end of the list of effects.
244
+ Don't use this argument unless you know exactly what you're doing!
245
+ return_output : bool, default=False
246
+ If True, returns the status and information sent to stderr and
247
+ stdout as a tuple (status, stdout, stderr).
248
+ output_filepath must not be None; a ValueError is raised otherwise.
249
+ If False, returns (True, None, None) on success.
250
+
251
+ Returns
252
+ -------
253
+ status : bool
254
+ True on success.
255
+ out : str (optional)
256
+ This is None unless return_output is True.
257
+ When returned, captures the stdout produced by sox.
258
+ err : str (optional)
259
+ This is None unless return_output is True.
260
+ When returned, captures the stderr produced by sox.
261
+
262
+ Examples
263
+ --------
264
+ > import numpy as np
265
+ > import sox
266
+ > tfm = sox.Transformer()
267
+ > sample_rate = 44100
268
+ > y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)
269
+
270
+ file in, file out - basic usage
271
+
272
+ > status = tfm.build('path/to/input.wav', 'path/to/output.mp3')
273
+
274
+ file in, file out - equivalent usage
275
+
276
+ > status = tfm.build(
277
+ input_filepath='path/to/input.wav',
278
+ output_filepath='path/to/output.mp3'
279
+ )
280
+
281
+ array in, file out
282
+
283
+ > status = tfm.build(
284
+ input_array=y, sample_rate_in=sample_rate,
285
+ output_filepath='path/to/output.mp3'
286
+ )
287
+
288
+ """
289
+ from sox import file_info
290
+ from sox.core import SoxError
291
+ from sox.core import sox
292
+ from sox.log import logger
293
+
294
+ input_format, input_filepath = self._parse_inputs(
295
+ input_filepath, input_array, sample_rate_in
296
+ )
297
+
298
+ if output_filepath is None:
299
+ raise ValueError("output_filepath is not specified!")
300
+
301
+ # set output parameters
302
+ if input_filepath == output_filepath:
303
+ raise ValueError(
304
+ "input_filepath must be different from output_filepath."
305
+ )
306
+ file_info.validate_output_file(output_filepath)
307
+
308
+ args = []
309
+ args.extend(self.globals)
310
+ args.extend(self._input_format_args(input_format))
311
+ args.append(input_filepath)
312
+ args.extend(self._output_format_args(self.output_format))
313
+ args.append(output_filepath)
314
+ args.extend(self.effects)
315
+
316
+ if extra_args is not None:
317
+ if not isinstance(extra_args, list):
318
+ raise ValueError("extra_args must be a list.")
319
+ args.extend(extra_args)
320
+
321
+ status, out, err = sox(args, input_array, True)
322
+ if status != 0:
323
+ raise SoxError(
324
+ f"Stdout: {out}\nStderr: {err}"
325
+ )
326
+
327
+ logger.info(
328
+ "Created %s with effects: %s",
329
+ output_filepath,
330
+ " ".join(self.effects_log)
331
+ )
332
+
333
+ if return_output:
334
+ return status, out, err
335
+
336
+ return True, None, None
337
+
213
338
  def build_array(self,
214
339
  input_filepath: Optional[str | Path] = None,
215
340
  input_array: Optional[np.ndarray] = None,
@@ -3,13 +3,14 @@ from sonusai.mixture.datatypes import TruthFunctionConfig
3
3
 
4
4
 
5
5
  class Data:
6
- def __init__(self, target_audio: AudioT,
6
+ def __init__(self,
7
+ target_audio: AudioT,
7
8
  noise_audio: AudioT,
8
9
  mixture_audio: AudioT,
9
10
  config: TruthFunctionConfig) -> None:
10
11
  import numpy as np
11
- from pyaaware import AawareForwardTransform
12
- from pyaaware import AawareInverseTransform
12
+ from sonusai import ForwardTransform
13
+ from sonusai import InverseTransform
13
14
  from pyaaware import FeatureGenerator
14
15
 
15
16
  from sonusai import SonusAIError
@@ -33,25 +34,25 @@ class Data:
33
34
 
34
35
  self.offsets = range(0, len(target_audio), self.frame_size)
35
36
  self.zero_based_indices = [x - 1 for x in config.index]
36
- self.target_fft = AawareForwardTransform(N=fg.ftransform_N,
37
- R=fg.ftransform_R,
38
- bin_start=fg.bin_start,
39
- bin_end=fg.bin_end,
40
- ttype=fg.ftransform_ttype)
41
- self.noise_fft = AawareForwardTransform(N=fg.ftransform_N,
42
- R=fg.ftransform_R,
43
- bin_start=fg.bin_start,
44
- bin_end=fg.bin_end,
45
- ttype=fg.ftransform_ttype)
46
- self.mixture_fft = AawareForwardTransform(N=fg.ftransform_N,
47
- R=fg.ftransform_R,
48
- bin_start=fg.bin_start,
49
- bin_end=fg.bin_end,
50
- ttype=fg.ftransform_ttype)
51
- self.swin = AawareInverseTransform(N=fg.itransform_N,
52
- R=fg.itransform_R,
37
+ self.target_fft = ForwardTransform(N=fg.ftransform_N,
38
+ R=fg.ftransform_R,
53
39
  bin_start=fg.bin_start,
54
40
  bin_end=fg.bin_end,
55
- ttype=fg.itransform_ttype,
56
- gain=np.float32(1)).W
41
+ ttype=fg.ftransform_ttype)
42
+ self.noise_fft = ForwardTransform(N=fg.ftransform_N,
43
+ R=fg.ftransform_R,
44
+ bin_start=fg.bin_start,
45
+ bin_end=fg.bin_end,
46
+ ttype=fg.ftransform_ttype)
47
+ self.mixture_fft = ForwardTransform(N=fg.ftransform_N,
48
+ R=fg.ftransform_R,
49
+ bin_start=fg.bin_start,
50
+ bin_end=fg.bin_end,
51
+ ttype=fg.ftransform_ttype)
52
+ self.swin = InverseTransform(N=fg.itransform_N,
53
+ R=fg.itransform_R,
54
+ bin_start=fg.bin_start,
55
+ bin_end=fg.bin_end,
56
+ ttype=fg.itransform_ttype,
57
+ gain=np.float32(1)).W
57
58
  self.truth = np.zeros((len(target_audio), config.num_classes), dtype=np.float32)
@@ -132,9 +132,11 @@ def energy_t(data: Data) -> Truth:
132
132
  will reflect the total energy over all bins regardless of the feature
133
133
  transform config.
134
134
  """
135
+ import torch
136
+
135
137
  from sonusai import SonusAIError
136
138
 
137
- _, target_energy = data.target_fft.execute_all(data.target_audio)
139
+ target_energy = data.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
138
140
  if len(target_energy) != len(data.offsets):
139
141
  raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
140
142
  f' is not number of frames in truth, {len(data.offsets)}')
@@ -21,6 +21,7 @@ should be set to the number of sounds/classes to be detected + 1 for
21
21
  the other class.
22
22
  """
23
23
  import numpy as np
24
+ import torch
24
25
  from pyaaware import SED
25
26
 
26
27
  from sonusai import SonusAIError
@@ -48,7 +49,7 @@ the other class.
48
49
  mutex=data.config.mutex)
49
50
 
50
51
  target_audio = data.target_audio / data.config.target_gain
51
- _, energy_t = data.target_fft.execute_all(target_audio)
52
+ energy_t = data.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
52
53
  if len(energy_t) != len(data.offsets):
53
54
  raise SonusAIError(f'Number of frames in energy_t, {len(energy_t)},'
54
55
  f' is not number of frames in truth, {len(data.offsets)}')
@@ -1,4 +1,4 @@
1
- from pyaaware import ForwardTransform
1
+ from sonusai import ForwardTransform
2
2
 
3
3
  from sonusai.mixture.datatypes import AudioF
4
4
  from sonusai.mixture.datatypes import AudioT
@@ -98,7 +98,6 @@ Output shape: [:, 2 * bins] (stacked real, imag)
98
98
  for idx, offset in enumerate(data.offsets):
99
99
  target_freq, _ = data.target_fft.execute(
100
100
  np.multiply(data.target_audio[offset:offset + data.frame_size], data.swin))
101
- target_freq = target_freq.transpose()
102
101
 
103
102
  indices = slice(offset, offset + data.frame_size)
104
103
  for index in data.zero_based_indices:
@@ -112,10 +111,10 @@ Output shape: [:, 2 * bins] (stacked real, imag)
112
111
 
113
112
 
114
113
  def _execute_fft(audio: AudioT, transform: ForwardTransform, expected_frames: int) -> AudioF:
114
+ import torch
115
115
  from sonusai import SonusAIError
116
116
 
117
- freq, _ = transform.execute_all(audio)
118
- freq = freq.transpose()
117
+ freq = transform.execute_all(torch.from_numpy(audio))[0].numpy()
119
118
  if len(freq) != expected_frames:
120
119
  raise SonusAIError(f'Number of frames, {len(freq)}, is not number of frames expected, {expected_frames}')
121
120
  return freq