sonusai 0.18.1__py3-none-any.whl → 0.18.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/mixdb.py CHANGED
@@ -4,9 +4,9 @@ from functools import partial
4
4
  from sqlite3 import Connection
5
5
  from sqlite3 import Cursor
6
6
  from typing import Any
7
- from typing import Callable
8
7
  from typing import Optional
9
8
 
9
+ from sonusai.mixture.datatypes import ASRConfigs
10
10
  from sonusai.mixture.datatypes import AudioF
11
11
  from sonusai.mixture.datatypes import AudioT
12
12
  from sonusai.mixture.datatypes import AudiosF
@@ -89,6 +89,7 @@ class MixtureDatabase:
89
89
  from .datatypes import MixtureDatabaseConfig
90
90
 
91
91
  config = MixtureDatabaseConfig(
92
+ asr_configs=self.asr_configs,
92
93
  class_balancing=self.class_balancing,
93
94
  class_labels=self.class_labels,
94
95
  class_weights_threshold=self.class_weights_thresholds,
@@ -146,6 +147,30 @@ class MixtureDatabase:
146
147
  with self.db() as c:
147
148
  return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
148
149
 
@cached_property
def asr_configs(self) -> ASRConfigs:
    """ASR configurations for this mixture database.

    Decoded from the JSON document stored in the ``asr_configs`` column of the ``top`` table.
    Being a cached_property, the decoded value is stored on the instance after first access.
    """
    import json

    query = "SELECT top.asr_configs FROM top"
    with self.db() as cursor:
        serialized = cursor.execute(query).fetchone()[0]
        return json.loads(serialized)
156
+
@cached_property
def supported_metrics(self) -> set[str]:
    """Names of the metrics that mixture_metrics() can report for this database.

    The result is the fixed set of mixture/target/noise metric names plus one 'mxwer.<name>' entry
    for each key of asr_configs. 'mxsnr' is included because mixture_metrics() handles it (it is
    read directly from the mixture record).

    NOTE(review): mixture_metrics() parses WER names as 'mxwer.<engine>.<model>' (exactly three
    dot-separated parts). Confirm that asr_configs keys have the '<engine>.<model>' form; otherwise
    the 'mxwer.<name>' entries advertised here are rejected there.

    :return: Set of supported metric names
    """
    metrics = {
        'mxsnr',
        'mxssnravg', 'mxssnrvar', 'mxssnrdavg', 'mxssnrdstd',
        'mxpesq', 'mxcsig', 'mxcbak', 'mxcovl', 'mxwsdr',
        'mxpd',
        'mxstoi',
        'tdco', 'tmin', 'tmax', 'tpkdb', 'tlrms', 'tpkr', 'ttr', 'tcr', 'tfl', 'tpkc',
        'ndco', 'nmin', 'nmax', 'npkdb', 'nlrms', 'npkr', 'ntr', 'ncr', 'nfl', 'npkc',
        'sedavg', 'sedcnt', 'sedtopn',
        'ssnr',
    }
    # One WER metric per configured ASR entry.
    metrics.update(f'mxwer.{name}' for name in self.asr_configs)

    return metrics
173
+
149
174
  @cached_property
150
175
  def class_balancing(self) -> bool:
151
176
  with self.db() as c:
@@ -1062,13 +1087,13 @@ class MixtureDatabase:
1062
1087
 
1063
1088
  def mixids_for_speech_metadata(self,
1064
1089
  tier: str,
1065
- value: str | None,
1066
- predicate: Callable[[str], bool] = None) -> list[int]:
1090
+ value: str = None,
1091
+ where: str = None) -> list[int]:
1067
1092
  """Get a list of mixture IDs for the given speech metadata tier.
1068
1093
 
1069
- If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
1070
- If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
1071
- to include.
1094
+ If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
1095
+ If 'where' is not None, then ignore 'value' and use the given SQL where clause to determine
1096
+ which entries to include.
1072
1097
 
1073
1098
  Examples:
1074
1099
  >>> mixdb = MixtureDatabase('/mixdb_location')
@@ -1076,208 +1101,319 @@ class MixtureDatabase:
1076
1101
  >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
1077
1102
  Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1078
1103
 
1079
- >>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
1104
+ >>> mixids = mixdb.mixids_for_speech_metadata('age', where='age < 25')
1080
1105
  Get mixture IDs for mixtures with speakers whose ages are less than 25.
1081
1106
 
1082
- >>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
1107
+ >>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
1083
1108
  Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1084
1109
  """
1085
- from .helpers import get_textgrid_tier_from_target_file
1110
+ from sonusai import SonusAIError
1086
1111
 
1087
- if predicate is None:
1088
- def predicate(x: str | None) -> bool:
1089
- return x == value
1112
+ if value is None and where is None:
1113
+ raise SonusAIError('Must provide either value or where')
1090
1114
 
1091
- # First get list of matching target files
1092
- target_file_ids: list[int] = []
1093
- is_textgrid = tier in self.textgrid_metadata_tiers
1094
- for target_file_id, target_file in enumerate(self.target_files):
1095
- if is_textgrid:
1096
- metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
1097
- else:
1098
- metadata = self.speaker(target_file.speaker_id, tier)
1115
+ if where is None:
1116
+ where = f"{tier} = '{value}'"
1099
1117
 
1100
- if not isinstance(metadata, list) and predicate(metadata):
1101
- target_file_ids.append(target_file_id + 1)
1118
+ if tier in self.textgrid_metadata_tiers:
1119
+ raise SonusAIError(f'TextGrid tier data, "{tier}", is not supported in mixids_for_speech_metadata().')
1102
1120
 
1103
- # Next get list of mixture IDs that contain those target files
1104
1121
  with self.db() as c:
1105
- m_ids = c.execute("SELECT mixture_id FROM mixture_target " +
1106
- f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
1107
- return [x[0] - 1 for x in m_ids]
1122
+ speaker_ids = [speaker_id[0] for speaker_id in
1123
+ c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()]
1124
+ results = c.execute(f"SELECT id FROM target_file " +
1125
+ f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})").fetchall()
1126
+ target_file_ids = [target_file_id[0] for target_file_id in results]
1127
+ results = c.execute("SELECT mixture_id FROM mixture_target " +
1128
+ f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
1129
+
1130
+ return [mixture_id[0] - 1 for mixture_id in results]
1108
1131
 
def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
    """Return all speech metadata for the given mixture.

    Looks up the mixture record for 'm_id' and delegates the actual collection to
    helpers.mixture_all_speech_metadata().
    """
    from .helpers import mixture_all_speech_metadata as _collect_all

    mixture = self.mixture(m_id)
    return _collect_all(self, mixture)
1113
1136
 
def mixture_metrics(self, m_id: int,
                    metrics: list[str],
                    force: bool = False) -> list[float | int | Segsnr]:
    """Get metrics data for the given mixture ID

    'mxsnr' is always read from the mixture record. Every other metric is first looked up in the
    cached mixture data (unless 'force' is set) and is computed from the original sources only when
    no cached value exists. Intermediate data shared by several metrics (target/noise/mixture audio,
    segmental SNR, speech quality metrics, audio statistics) is computed lazily and at most once
    per call.

    WER metrics must be named 'mxwer.<engine>.<model>'.

    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: List of metric data, in the same order as 'metrics'
    :raises SonusAIError: If a metric name is not recognized or an 'mxwer' name is malformed
    """
    from functools import cache

    import numpy as np
    from pystoi import stoi

    from sonusai import SonusAIError
    from sonusai.metrics import calc_audio_stats
    from sonusai.metrics import calc_phase_distance
    from sonusai.metrics import calc_snr_f
    from sonusai.metrics import calc_speech
    from sonusai.metrics import calc_wer
    from sonusai.metrics import calc_wsdr
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.mixture import AudioStatsMetrics
    from sonusai.mixture import SpeechMetrics
    from sonusai.utils import calc_asr

    # Mixtures whose SNR is below this are treated as noise only.
    noise_only_snr = -96

    # 't'/'n' + one of these names selects that attribute of the target/noise audio statistics,
    # e.g. 'tpkdb' -> target_stats().pkdb, 'ncr' -> noise_stats().cr.
    audio_stats_fields = {'dco', 'min', 'max', 'pkdb', 'lrms', 'pkr', 'tr', 'cr', 'fl', 'pkc'}

    # Speech quality metric name -> SpeechMetrics attribute.
    speech_fields = {'mxpesq': 'pesq', 'mxcsig': 'c_sig', 'mxcbak': 'c_bak', 'mxcovl': 'c_ovl'}

    # The helpers below are memoized with functools.cache; they are redefined on every call, so
    # the memoization only lasts for this call.
    @cache
    def mixture_snr():
        return self.mixture(m_id).snr

    @cache
    def target_audio() -> np.ndarray:
        return self.mixture_target(m_id)

    @cache
    def noise_audio() -> np.ndarray:
        return self.mixture_noise(m_id)

    @cache
    def mixture_audio() -> np.ndarray:
        return self.mixture_mixture(m_id)

    @cache
    def segsnr_f() -> np.ndarray:
        return self.mixture_segsnr(m_id)

    @cache
    def snr_f_stats():
        return calc_snr_f(segsnr_f())

    @cache
    def speech() -> SpeechMetrics:
        return calc_speech(hypothesis=mixture_audio(), reference=target_audio())

    @cache
    def target_stats() -> AudioStatsMetrics:
        return calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)

    @cache
    def noise_stats() -> AudioStatsMetrics:
        return calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)

    def mixture_wer(m: str) -> float:
        """Compute WER (in percent) of the mixture transcript against the target transcript."""
        parts = m.split('.')
        if len(parts) != 3:
            raise SonusAIError(
                f"Unrecognized 'mxwer' format: '{m}'; must be of the form: 'mxwer.<engine>.<model>'")
        asr_engine = parts[1]
        asr_model = parts[2]

        if asr_engine == 'none' or mixture_snr() < noise_only_snr:
            # noise only, ignore/reset target asr
            return float('nan')

        # ignore mixup
        target_asr = self.mixture_speech_metadata(m_id, 'text')[0]
        if target_asr is None:
            target_asr = calc_asr(target_audio(), engine=asr_engine, whisper_model_name=asr_model).text

        if target_asr:
            mixture_asr = calc_asr(mixture_audio(), engine=asr_engine, whisper_model_name=asr_model).text
            return calc_wer(mixture_asr, target_asr).wer * 100

        # TODO: should this be NaN like above?
        return float(0)

    def calc(m: str) -> float | int | Segsnr:
        if m == 'mxsnr':
            return mixture_snr()

        # Get cached data first, if exists
        if not force:
            value = self.read_mixture_data(m_id, m)
            if value is not None:
                return value

        # Otherwise, generate data as needed
        if m.startswith('mxwer'):
            return mixture_wer(m)

        if m == 'mxssnravg':
            return snr_f_stats().mean

        if m == 'mxssnrvar':
            return snr_f_stats().var

        if m == 'mxssnrdavg':
            return snr_f_stats().db_mean

        if m == 'mxssnrdstd':
            return snr_f_stats().db_std

        if m in speech_fields:
            # Noise-only mixtures report 0 for speech quality metrics.
            if mixture_snr() < noise_only_snr:
                return 0
            return getattr(speech(), speech_fields[m])

        if m == 'mxwsdr':
            mixture = mixture_audio()[:, np.newaxis]
            target = target_audio()[:, np.newaxis]
            noise = noise_audio()[:, np.newaxis]
            return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
                             reference=np.concatenate((target, noise), axis=1),
                             with_log=True)[0]

        if m == 'mxpd':
            mixture_f = self.mixture_mixture_f(m_id)
            target_f = self.mixture_target_f(m_id)
            return calc_phase_distance(hypothesis=mixture_f, reference=target_f)[0]

        if m == 'mxstoi':
            return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)

        if m[:1] == 't' and m[1:] in audio_stats_fields:
            return getattr(target_stats(), m[1:])

        if m[:1] == 'n' and m[1:] in audio_stats_fields:
            return getattr(noise_stats(), m[1:])

        if m in ('sedavg', 'sedcnt', 'sedtopn'):
            # NOTE(review): placeholder; SED metrics are not computed here and always report 0.
            return 0

        if m == 'ssnr':
            return segsnr_f()

        raise SonusAIError(f"Unrecognized metric: '{m}'")

    return [calc(metric) for metric in metrics]
1281
1417
 
1282
1418
 
1283
1419
  @lru_cache
@@ -1,8 +1,10 @@
1
+ from pathlib import Path
2
+
1
3
  from sonusai.mixture.datatypes import AudioT
2
4
  from sonusai.mixture.datatypes import ImpulseResponseData
3
5
 
4
6
 
5
- def _raw_read(name: str) -> tuple[AudioT, int]:
7
+ def _raw_read(name: str | Path) -> tuple[AudioT, int]:
6
8
  import numpy as np
7
9
  import soundfile
8
10
  from pydub import AudioSegment
@@ -34,7 +36,7 @@ def _raw_read(name: str) -> tuple[AudioT, int]:
34
36
  return np.squeeze(raw[:, 0]), sample_rate
35
37
 
36
38
 
37
- def get_sample_rate(name: str) -> int:
39
+ def get_sample_rate(name: str | Path) -> int:
38
40
  """Get sample rate from audio file using soundfile
39
41
 
40
42
  :param name: File name
@@ -63,7 +65,7 @@ def get_sample_rate(name: str) -> int:
63
65
  raise SonusAIError(f'Error reading {name}: {e}')
64
66
 
65
67
 
66
- def read_ir(name: str) -> ImpulseResponseData:
68
+ def read_ir(name: str | Path) -> ImpulseResponseData:
67
69
  """Read impulse response data using soundfile
68
70
 
69
71
  :param name: File name
@@ -79,10 +81,10 @@ def read_ir(name: str) -> ImpulseResponseData:
79
81
  out = out[offset:]
80
82
  out = out / np.linalg.norm(out)
81
83
 
82
- return ImpulseResponseData(name=name, sample_rate=sample_rate, data=out)
84
+ return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=out)
83
85
 
84
86
 
85
- def read_audio(name: str) -> AudioT:
87
+ def read_audio(name: str | Path) -> AudioT:
86
88
  """Read audio data from a file using soundfile
87
89
 
88
90
  :param name: File name
@@ -101,7 +103,7 @@ def read_audio(name: str) -> AudioT:
101
103
  return out
102
104
 
103
105
 
104
- def get_num_samples(name: str) -> int:
106
+ def get_num_samples(name: str | Path) -> int:
105
107
  """Get the number of samples resampled to the SonusAI sample rate in the given file
106
108
 
107
109
  :param name: File name
@@ -1,16 +1,19 @@
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
1
5
  from sox import Transformer as SoxTransformer
2
6
 
3
7
  from sonusai.mixture.datatypes import AudioT
4
8
  from sonusai.mixture.datatypes import ImpulseResponseData
5
9
 
6
10
 
7
- def read_impulse_response(name: str) -> ImpulseResponseData:
11
+ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
8
12
  """Read impulse response data using SoX
9
13
 
10
14
  :param name: File name
11
15
  :return: ImpulseResponseData object
12
16
  """
13
- import numpy as np
14
17
  from scipy.io import wavfile
15
18
 
16
19
  from sonusai import SonusAIError
@@ -33,10 +36,10 @@ def read_impulse_response(name: str) -> ImpulseResponseData:
33
36
  data = data[offset:]
34
37
  data = data / np.linalg.norm(data)
35
38
 
36
- return ImpulseResponseData(name=name, sample_rate=sample_rate, data=data)
39
+ return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
37
40
 
38
41
 
39
- def read_audio(name: str) -> AudioT:
42
+ def read_audio(name: str | Path) -> AudioT:
40
43
  """Read audio data from a file using SoX
41
44
 
42
45
  :param name: File name
@@ -44,7 +47,6 @@ def read_audio(name: str) -> AudioT:
44
47
  """
45
48
  from typing import Any
46
49
 
47
- import numpy as np
48
50
  from sox.core import sox
49
51
 
50
52
  from sonusai import SonusAIError
@@ -208,8 +210,11 @@ class Transformer(SoxTransformer):
208
210
 
209
211
  return self
210
212
 
211
- def build_array(self, input_filepath=None, input_array=None,
212
- sample_rate_in=None, extra_args=None):
213
+ def build_array(self,
214
+ input_filepath: Optional[str | Path] = None,
215
+ input_array: Optional[np.ndarray] = None,
216
+ sample_rate_in: Optional[int] = None,
217
+ extra_args: Optional[list[str]] = None) -> np.ndarray:
213
218
  """Given an input file or array, returns the output as a numpy array
214
219
  by executing the current set of commands. By default, the array will
215
220
  have the same sample rate as the input file unless otherwise specified
@@ -220,7 +225,7 @@ class Transformer(SoxTransformer):
220
225
 
221
226
  Parameters
222
227
  ----------
223
- input_filepath : str or None
228
+ input_filepath : str, Path or None
224
229
  Either path to input audio file or None.
225
230
  input_array : np.ndarray or None
226
231
  A np.ndarray of a waveform with shape (n_samples, n_channels).
@@ -270,8 +275,6 @@ class Transformer(SoxTransformer):
270
275
 
271
276
 
272
277
  """
273
- import numpy as np
274
-
275
278
  from sox.core import SoxError
276
279
  from sox.core import sox
277
280
  from sox.log import logger
@@ -324,13 +327,13 @@ class Transformer(SoxTransformer):
324
327
 
325
328
  match n_bits:
326
329
  case 8:
327
- encoding_out = np.int8
330
+ encoding_out = np.int8 # type: ignore
328
331
  case 16:
329
332
  encoding_out = np.int16
330
333
  case 32:
331
- encoding_out = np.float32
334
+ encoding_out = np.float32 # type: ignore
332
335
  case 64:
333
- encoding_out = np.float64
336
+ encoding_out = np.float64 # type: ignore
334
337
  case _:
335
338
  raise ValueError("invalid n_bits {}".format(n_bits))
336
339