sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -4,44 +4,46 @@ from sonusai.mixture.datatypes import Predict
4
4
  from sonusai.mixture.datatypes import Truth
5
5
 
6
6
 
7
- def one_hot(truth: Truth,
8
- predict: Predict,
9
- predict_thr: float | np.ndarray = 0,
10
- truth_thr: float = 0.5,
11
- timesteps: int = -1) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
12
- """ Calculates metrics from one-hot prediction and truth data (numpy float arrays) where
13
- both are one-hot probabilities (or quantized decisions) for each class
14
- with size [frames, num_classes] or [frames, timesteps, num_classes].
15
- For metrics that require it, truth and pred decisions will be made using threshold >= predict_thr.
16
- Some metrics like AP and AUC do not depend on predict_thr for predict, but still use truth >= predict_thr
17
-
18
- predict_thr sets the decision threshold(s) applied to predict data for some metrics, thus allowing
19
- the input to be continuous probabilities, for AUC-type metrics and root-mean-square error (rmse).
20
- 1. Default = 0 (multiclass or binary) which infers:
21
- binary (num_classes = 1) use >= 0.5 for truth and pred (same as argmax() for binary)
22
- multi-class/single-label if truth_mutex= = true, use argmax() used on both truth and pred
23
- note multilabel metrics are disabled for predict_thr = 0, must set predict_thr > 0
24
-
25
- 2. predict_thr > 0 (multilabel or binary) scalar or a vector [num_classes, 1] then use
26
- predict_thr as a binary decision threshold in each class:
27
- binary (num_classes = 1) use >= predict_thr[0] for pred and predict_thr[num_classes+1] for truth
28
- if it exists, else use >= 0.5 for truth
29
- multilabel use >= predict_thr for pred if scalar, or predict_thr[class_idx] if vector
30
- use >= predict_thr[num_classes+1] for truth if exists, else 0.5
31
- note multi-class/single-label inputs are meaningless in this mode, use predict_thr = 0 argmax mode
32
-
33
- num_classes is inferred from 1D, 2D, or 3D truth inputs by default (default timesteps = -1 which implies None).
34
- Only set timesteps > 0 in case of ambiguous binary 2D case where input [frames, timesteps],
35
- then it must set to the number of timesteps (which will be > 0).
36
- It is safe to always set timesteps <= 0 for binary inputs, and if truth.shape[2] exists
37
-
38
- returns metrics over all frames + timesteps:
39
- mcm [num_classes, 2, 2] multiclass confusion matrix count ove
40
- metrics [num_classes, 14] [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
41
- cm [num_classes, num_classes] confusion matrix
42
- cmn [num_classes, num_classes] normalized confusion matrix
43
- rmse [num_classes, 1] RMS error over all frames + timesteps, before threshold decision
44
- mavg [3, 8] averages macro, micro, weighted [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
7
+ def one_hot(
8
+ truth: Truth,
9
+ predict: Predict,
10
+ predict_thr: float | np.ndarray = 0,
11
+ truth_thr: float = 0.5,
12
+ timesteps: int = -1,
13
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
14
+ """Calculates metrics from one-hot prediction and truth data (numpy float arrays) where
15
+ both are one-hot probabilities (or quantized decisions) for each class
16
+ with size [frames, num_classes] or [frames, timesteps, num_classes].
17
+ For metrics that require it, truth and pred decisions will be made using threshold >= predict_thr.
18
+ Some metrics like AP and AUC do not depend on predict_thr for predict, but still use truth >= predict_thr
19
+
20
+ predict_thr sets the decision threshold(s) applied to predict data for some metrics, thus allowing
21
+ the input to be continuous probabilities, for AUC-type metrics and root-mean-square error (rmse).
22
+ 1. Default = 0 (multiclass or binary) which infers:
23
+ binary (num_classes = 1) use >= 0.5 for truth and pred (same as argmax() for binary)
24
+ multi-class/single-label if truth_mutex= = true, use argmax() used on both truth and pred
25
+ note multilabel metrics are disabled for predict_thr = 0, must set predict_thr > 0
26
+
27
+ 2. predict_thr > 0 (multilabel or binary) scalar or a vector [num_classes, 1] then use
28
+ predict_thr as a binary decision threshold in each class:
29
+ binary (num_classes = 1) use >= predict_thr[0] for pred and predict_thr[num_classes+1] for truth
30
+ if it exists, else use >= 0.5 for truth
31
+ multilabel use >= predict_thr for pred if scalar, or predict_thr[class_idx] if vector
32
+ use >= predict_thr[num_classes+1] for truth if exists, else 0.5
33
+ note multi-class/single-label inputs are meaningless in this mode, use predict_thr = 0 argmax mode
34
+
35
+ num_classes is inferred from 1D, 2D, or 3D truth inputs by default (default timesteps = -1 which implies None).
36
+ Only set timesteps > 0 in case of ambiguous binary 2D case where input [frames, timesteps],
37
+ then it must set to the number of timesteps (which will be > 0).
38
+ It is safe to always set timesteps <= 0 for binary inputs, and if truth.shape[2] exists
39
+
40
+ returns metrics over all frames + timesteps:
41
+ mcm [num_classes, 2, 2] multiclass confusion matrix count ove
42
+ metrics [num_classes, 14] [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
43
+ cm [num_classes, num_classes] confusion matrix
44
+ cmn [num_classes, num_classes] normalized confusion matrix
45
+ rmse [num_classes, 1] RMS error over all frames + timesteps, before threshold decision
46
+ mavg [3, 8] averages macro, micro, weighted [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
45
47
  """
46
48
  import warnings
47
49
 
@@ -51,14 +53,13 @@ def one_hot(truth: Truth,
51
53
  from sklearn.metrics import precision_recall_fscore_support
52
54
  from sklearn.metrics import roc_auc_score
53
55
 
54
- from sonusai import SonusAIError
55
56
  from sonusai.utils import get_num_classes_from_predict
56
57
  from sonusai.utils import reshape_outputs
57
58
 
58
59
  if truth.shape != predict.shape:
59
- raise SonusAIError('truth and predict are not the same shape')
60
+ raise ValueError("truth and predict are not the same shape")
60
61
 
61
- predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
62
+ predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps) # type: ignore[assignment]
62
63
  num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
63
64
 
64
65
  # Regression metric root-mean-square-error always works
@@ -79,7 +80,8 @@ def one_hot(truth: Truth,
79
80
  else:
80
81
  if predict_thr.ndim > 1:
81
82
  # multilabel with custom thr vector
82
- assert predict_thr.shape[0] == num_classes
83
+ if predict_thr.shape[0] != num_classes:
84
+ raise ValueError("predict_thr has wrong shape")
83
85
  else:
84
86
  if predict_thr == 0:
85
87
  # binary or multilabel scalar default
@@ -89,18 +91,18 @@ def one_hot(truth: Truth,
89
91
  predict_thr = np.atleast_1d(predict_thr)
90
92
 
91
93
  if not isinstance(predict_thr, np.ndarray):
92
- raise SonusAIError(f'predict_thr is invalid type: {type(predict_thr)}')
94
+ raise TypeError(f"predict_thr is invalid type: {type(predict_thr)}")
93
95
 
94
96
  # Convert continuous probabilities to binary via argmax() or threshold comparison
95
97
  # and create labels of int encoded (0:num_classes-1), and then equivalent one-hot
96
98
  if num_classes == 1: # If binary
97
- labels = ([i for i in range(0, 2)]) # int encoded 0,1
99
+ labels = list(range(0, 2)) # int encoded 0,1
98
100
  plabel = np.int8(predict >= predict_thr) # [frames, 1], default 0.5 is equiv. to argmax()
99
101
  tlabel = np.int8(truth >= truth_thr) # [frames, 1]
100
102
  predb = np.array(plabel)
101
103
  truthb = np.array(tlabel)
102
104
  else:
103
- labels = ([i for i in range(0, num_classes)]) # int encoded 0,...,num_classes-1
105
+ labels = list(range(0, num_classes)) # int encoded 0,...,num_classes-1
104
106
  if predict_thr[0] == 0: # multiclass single-label (mutex), use argmax
105
107
  plabel = np.argmax(predict, axis=-1) # [frames, 1] labels
106
108
  tlabel = np.argmax(truth, axis=-1) # [frames, 1] labels
@@ -134,7 +136,7 @@ def one_hot(truth: Truth,
134
136
  mcm = mcm[1:] # remove dim 0 if binary
135
137
 
136
138
  # Create [num_classes, num_classes] normalized confusion matrix
137
- cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize='true')
139
+ cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize="true")
138
140
 
139
141
  # Create [num_classes, num_classes] confusion matrix
140
142
  cm = confusion_matrix(tlabel, plabel, labels=labels)
@@ -194,7 +196,22 @@ def one_hot(truth: Truth,
194
196
  # ix = np.argmax(fscore) # index of largest f1 score
195
197
  # threshold_optpr[nci] = thresholds[ix]
196
198
 
197
- metrics[nci, :] = [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
199
+ metrics[nci, :] = [
200
+ ACC,
201
+ TPR,
202
+ PPV,
203
+ TNR,
204
+ FPR,
205
+ HITFA,
206
+ F1,
207
+ MCC,
208
+ NT,
209
+ PT,
210
+ TP,
211
+ FP,
212
+ AP,
213
+ AUC,
214
+ ]
198
215
 
199
216
  # Calculate averages into single array, 3 types for now Macro, Micro, Weighted
200
217
  mavg = np.zeros((3, 8), dtype=np.float32)
@@ -202,9 +219,17 @@ def one_hot(truth: Truth,
202
219
 
203
220
  # macro average [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
204
221
  with warnings.catch_warnings():
205
- warnings.filterwarnings(action='ignore', message='Mean of empty slice')
206
- mavg[0, :] = [np.mean(metrics[:, 2]), np.mean(metrics[:, 1]), np.mean(metrics[:, 6]), np.mean(metrics[:, 4]),
207
- np.mean(metrics[:, 0]), np.nanmean(metrics[:, 12]), np.nanmean(metrics[:, 13]), s]
222
+ warnings.filterwarnings(action="ignore", message="Mean of empty slice")
223
+ mavg[0, :] = [
224
+ np.mean(metrics[:, 2]),
225
+ np.mean(metrics[:, 1]),
226
+ np.mean(metrics[:, 6]),
227
+ np.mean(metrics[:, 4]),
228
+ np.mean(metrics[:, 0]),
229
+ np.nanmean(metrics[:, 12]),
230
+ np.nanmean(metrics[:, 13]),
231
+ s,
232
+ ]
208
233
 
209
234
  # micro average, micro-F1 = micro-precision = micro-recall = accuracy
210
235
  if num_classes > 1:
@@ -218,25 +243,34 @@ def one_hot(truth: Truth,
218
243
  tn_sum = sum(mcm[:, 0, 0])
219
244
  accm = (tp_sum + tn_sum) / (tp_sum + tn_sum + fp_sum + fn_sum + eps)
220
245
  with warnings.catch_warnings():
221
- warnings.filterwarnings(action='ignore', message='invalid value encountered in true_divide')
222
- miap = average_precision_score(truthb, predict, average='micro')
246
+ warnings.filterwarnings(action="ignore", message="invalid value encountered in true_divide")
247
+ miap = average_precision_score(truthb, predict, average="micro")
223
248
  if np.sum(truthb): # no activity over all classes
224
- miauc = roc_auc_score(truthb, predict, average='micro')
249
+ miauc = roc_auc_score(truthb, predict, average="micro")
225
250
  else:
226
251
  miauc = np.NaN
227
252
 
228
253
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
229
- mavg[1, :] = [pm, rm, f1m, fpm, accm, miap, miauc, s] # specific format, last 3 are unique
254
+ mavg[1, :] = [
255
+ pm,
256
+ rm,
257
+ f1m,
258
+ fpm,
259
+ accm,
260
+ miap,
261
+ miauc,
262
+ s,
263
+ ] # specific format, last 3 are unique
230
264
 
231
265
  # weighted average TBD
232
- wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average='weighted', zero_division=0)
266
+ wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
233
267
  if np.sum(truthb):
234
268
  taidx = np.sum(truthb, axis=0) > 0
235
- wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average='weighted')
269
+ wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
236
270
  if len(np.unique(truthb[:, taidx])) < 2:
237
271
  wauc = np.NaN
238
272
  else:
239
- wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average='weighted')
273
+ wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average="weighted")
240
274
  else:
241
275
  wap = np.NaN
242
276
  wauc = np.NaN
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: F821
1
2
  import numpy as np
2
3
  import pandas as pd
3
4
 
@@ -8,32 +9,34 @@ from sonusai.mixture import Segsnr
8
9
  from sonusai.mixture import Truth
9
10
 
10
11
 
11
- def snr_summary(mixdb: MixtureDatabase,
12
- mixid: GeneralizedIDs,
13
- truth_f: Truth,
14
- predict: Predict,
15
- segsnr: Segsnr = None,
16
- predict_thr: float | np.ndarray = 0,
17
- truth_thr: float = 0.5,
18
- timesteps: int = 0) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
12
+ def snr_summary(
13
+ mixdb: MixtureDatabase,
14
+ mixid: GeneralizedIDs,
15
+ truth_f: Truth,
16
+ predict: Predict,
17
+ segsnr: Segsnr | None = None,
18
+ predict_thr: float | np.ndarray = 0,
19
+ truth_thr: float = 0.5,
20
+ timesteps: int = 0,
21
+ ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
19
22
  """Calculate average-over-class metrics per SNR over specified mixture list.
20
- Inputs:
21
- mixdb Mixture database
22
- mixid
23
- truth_f Truth/labels [features, num_classes]
24
- predict Prediction data / neural net model one-hot out [features, num_classes]
25
- segsnr Segmental SNR from SonusAI genft [transform_frames, 1]
26
- predict_thr Decision threshold(s) applied to predict data, allowing predict to be
27
- continuous probabilities or decisions
28
- truth_thr Decision threshold(s) applied to truth data, allowing truth to be
29
- continuous probabilities or decisions
30
- timesteps
31
-
32
- Default predict_thr=0 will infer 0.5 for multi-label mode (truth_mutex = False), or
33
- if single-label mode (truth_mutex == True) then ignore and use argmax mode, and
34
- the confusion matrix is calculated for all classes.
35
-
36
- Returns pandas dataframe (snrdf) of metrics per SNR.
23
+ Inputs:
24
+ mixdb Mixture database
25
+ mixid
26
+ truth_f Truth/labels [features, num_classes]
27
+ predict Prediction data / neural net model one-hot out [features, num_classes]
28
+ segsnr Segmental SNR from SonusAI genft [transform_frames, 1]
29
+ predict_thr Decision threshold(s) applied to predict data, allowing predict to be
30
+ continuous probabilities or decisions
31
+ truth_thr Decision threshold(s) applied to truth data, allowing truth to be
32
+ continuous probabilities or decisions
33
+ timesteps
34
+
35
+ Default predict_thr=0 will infer 0.5 for multi-label mode (truth_mutex = False), or
36
+ if single-label mode (truth_mutex == True) then ignore and use argmax mode, and
37
+ the confusion matrix is calculated for all classes.
38
+
39
+ Returns pandas dataframe (snrdf) of metrics per SNR.
37
40
  """
38
41
  import warnings
39
42
 
@@ -53,14 +56,13 @@ def snr_summary(mixdb: MixtureDatabase,
53
56
  else:
54
57
  predict_thr = np.atleast_1d(predict_thr)
55
58
  else:
56
- if predict_thr.ndim == 1:
57
- if len(predict_thr) == 1:
58
- if predict_thr[0] == 0:
59
- # multi-label predict_thr array scalar 0 force to 0.5 default
60
- predict_thr = np.atleast_1d(0.5)
61
- else:
62
- # multi-label predict_thr array set to scalar = array[0]
63
- predict_thr = predict_thr[0]
59
+ if predict_thr.ndim == 1 and len(predict_thr) == 1:
60
+ if predict_thr[0] == 0:
61
+ # multi-label predict_thr array scalar 0 force to 0.5 default
62
+ predict_thr = np.atleast_1d(0.5)
63
+ else:
64
+ # multi-label predict_thr array set to scalar = array[0]
65
+ predict_thr = predict_thr[0]
64
66
 
65
67
  macro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
66
68
  micro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
@@ -72,13 +74,16 @@ def snr_summary(mixdb: MixtureDatabase,
72
74
  # prep segsnr if provided, transform frames to feature frames via mean()
73
75
  # expected to always be an integer
74
76
  feature_frames = int(segsnr.shape[0] / truth_f.shape[0])
75
- segsnr_f = np.mean(np.reshape(segsnr, (truth_f.shape[0], feature_frames)), axis=1, keepdims=True)
77
+ segsnr_f = np.mean(
78
+ np.reshape(segsnr, (truth_f.shape[0], feature_frames)),
79
+ axis=1,
80
+ keepdims=True,
81
+ )
76
82
  ssnr_stats = np.zeros((len(snr_mixids), 3), dtype=np.float32)
77
83
 
78
- ii = 0
79
- for snr in snr_mixids:
84
+ for ii, snr in enumerate(snr_mixids):
80
85
  # TODO: re-work for modern mixdb API
81
- y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore
86
+ y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
82
87
  _, metrics, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
83
88
 
84
89
  # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
@@ -87,20 +92,18 @@ def snr_summary(mixdb: MixtureDatabase,
87
92
  wghtd_avg[ii, :] = mavg[2, 0:7]
88
93
  if segsnr is not None:
89
94
  # TODO: re-work for modern mixdb API
90
- y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f) # type: ignore
95
+ y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f) # type: ignore[name-defined]
91
96
  with warnings.catch_warnings():
92
- warnings.filterwarnings(action='ignore', message='divide by zero encountered in log10')
97
+ warnings.filterwarnings(action="ignore", message="divide by zero encountered in log10")
93
98
  # segmental SNR mean = mixture_snr and target_snr
94
- ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr))
99
+ ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr)) # type: ignore[index]
95
100
  # segmental SNR 80% percentile
96
- ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method='midpoint'))
101
+ ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method="midpoint")) # type: ignore[index]
97
102
  # segmental SNR max
98
- ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr))
99
-
100
- ii += 1
103
+ ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr)) # type: ignore[index]
101
104
 
102
105
  # SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
103
- col_n = ['PPV', 'TPR', 'F1', 'FPR', 'ACC', 'AP', 'AUC']
106
+ col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
104
107
  snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
105
108
  snr_macrodf.sort_index(ascending=False, inplace=True)
106
109
 
@@ -112,7 +115,11 @@ def snr_summary(mixdb: MixtureDatabase,
112
115
 
113
116
  # Add segmental SNR columns if provided
114
117
  if segsnr is not None:
115
- ssnrdf = pd.DataFrame(ssnr_stats, index=list(snr_mixids.keys()), columns=['SSNRavg', 'SSNR80p', 'SSNRmax'])
118
+ ssnrdf = pd.DataFrame(
119
+ ssnr_stats,
120
+ index=list(snr_mixids.keys()),
121
+ columns=["SSNRavg", "SSNR80p", "SSNRmax"],
122
+ )
116
123
  ssnrdf.sort_index(ascending=False, inplace=True)
117
124
  snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
118
125
  snr_microdf = pd.concat([snr_microdf, ssnrdf], axis=1)
@@ -1,4 +1,6 @@
1
1
  # SonusAI mixture utilities
2
+ # ruff: noqa: F401
3
+
2
4
  from .audio import get_duration
3
5
  from .audio import get_next_noise
4
6
  from .audio import get_num_samples
@@ -19,15 +21,15 @@ from .augmentation import pad_audio_to_length
19
21
  from .class_count import get_class_count_from_mixids
20
22
  from .config import get_default_config
21
23
  from .config import get_impulse_response_files
22
- from .config import get_max_class
23
24
  from .config import get_noise_files
24
25
  from .config import get_spectral_masks
25
26
  from .config import get_target_files
27
+ from .config import get_truth_parameters
26
28
  from .config import load_config
27
29
  from .config import raw_load_config
28
30
  from .config import update_config_from_file
29
31
  from .config import update_config_from_hierarchy
30
- from .config import update_truth_settings
32
+ from .config import validate_truth_configs
31
33
  from .constants import BIT_DEPTH
32
34
  from .constants import CHANNEL_COUNT
33
35
  from .constants import DEFAULT_CONFIG
@@ -35,19 +37,22 @@ from .constants import DEFAULT_NOISE
35
37
  from .constants import DEFAULT_SPEECH
36
38
  from .constants import ENCODING
37
39
  from .constants import FLOAT_BYTES
40
+ from .constants import MIXDB_VERSION
38
41
  from .constants import RAND_PATTERN
39
42
  from .constants import REQUIRED_CONFIGS
43
+ from .constants import REQUIRED_TRUTH_CONFIGS
40
44
  from .constants import SAMPLE_BYTES
41
45
  from .constants import SAMPLE_RATE
42
46
  from .constants import VALID_AUGMENTATIONS
43
47
  from .constants import VALID_CONFIGS
44
48
  from .constants import VALID_NOISE_MIX_MODES
45
- from .constants import VALID_TRUTH_SETTINGS
49
+ from .data_io import read_cached_data
50
+ from .data_io import write_cached_data
46
51
  from .datatypes import AudioF
47
- from .datatypes import AudioStatsMetrics
48
- from .datatypes import AudioT
49
52
  from .datatypes import AudiosF
50
53
  from .datatypes import AudiosT
54
+ from .datatypes import AudioStatsMetrics
55
+ from .datatypes import AudioT
51
56
  from .datatypes import Augmentation
52
57
  from .datatypes import AugmentationRule
53
58
  from .datatypes import AugmentationRules
@@ -60,10 +65,11 @@ from .datatypes import EnergyT
60
65
  from .datatypes import Feature
61
66
  from .datatypes import FeatureGeneratorConfig
62
67
  from .datatypes import FeatureGeneratorInfo
68
+ from .datatypes import GeneralizedIDs
63
69
  from .datatypes import GenFTData
64
70
  from .datatypes import GenMixData
65
- from .datatypes import GeneralizedIDs
66
71
  from .datatypes import ImpulseResponseData
72
+ from .datatypes import ImpulseResponseFile
67
73
  from .datatypes import ImpulseResponseFiles
68
74
  from .datatypes import ListAudiosT
69
75
  from .datatypes import MetricDoc
@@ -84,9 +90,10 @@ from .datatypes import TargetFile
84
90
  from .datatypes import TargetFiles
85
91
  from .datatypes import TransformConfig
86
92
  from .datatypes import Truth
87
- from .datatypes import TruthFunctionConfig
88
- from .datatypes import TruthSetting
89
- from .datatypes import TruthSettings
93
+ from .datatypes import TruthConfig
94
+ from .datatypes import TruthConfigs
95
+ from .datatypes import TruthParameter
96
+ from .datatypes import TruthParameters
90
97
  from .datatypes import UniversalSNR
91
98
  from .feature import get_audio_from_feature
92
99
  from .feature import get_feature_from_audio
@@ -101,6 +108,7 @@ from .generation import populate_noise_file_table
101
108
  from .generation import populate_spectral_mask_table
102
109
  from .generation import populate_target_file_table
103
110
  from .generation import populate_top_table
111
+ from .generation import populate_truth_parameters_table
104
112
  from .generation import update_mixid_width
105
113
  from .generation import update_mixture
106
114
  from .helpers import augmented_noise_samples
@@ -112,11 +120,9 @@ from .helpers import get_audio_from_transform
112
120
  from .helpers import get_ft
113
121
  from .helpers import get_segsnr
114
122
  from .helpers import get_transform_from_audio
115
- from .helpers import get_truth_t
123
+ from .helpers import get_truth
116
124
  from .helpers import inverse_transform
117
125
  from .helpers import mixture_metadata
118
- from .helpers import read_mixture_data
119
- from .helpers import write_mixture_data
120
126
  from .helpers import write_mixture_metadata
121
127
  from .log_duration_and_sizes import log_duration_and_sizes
122
128
  from .mixdb import MixtureDatabase
@@ -128,9 +134,8 @@ from .targets import get_augmented_target_ids_by_class
128
134
  from .targets import get_augmented_target_ids_for_mixup
129
135
  from .targets import get_augmented_targets
130
136
  from .targets import get_target_augmentations_for_mixup
131
- from .targets import get_truth_indices_for_target
132
137
  from .tokenized_shell_vars import tokenized_expand
133
138
  from .tokenized_shell_vars import tokenized_replace
134
139
  from .truth import get_truth_indices_for_mixid
135
140
  from .truth import truth_function
136
- from .truth import truth_reduction
141
+ from .truth import truth_stride_reduction
sonusai/mixture/audio.py CHANGED
@@ -15,7 +15,7 @@ def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
15
15
  """
16
16
  import numpy as np
17
17
 
18
- return np.take(audio, range(offset, offset + length), mode='wrap')
18
+ return np.take(audio, range(offset, offset + length), mode="wrap")
19
19
 
20
20
 
21
21
  def get_duration(audio: AudioT) -> float:
@@ -35,15 +35,13 @@ def validate_input_file(input_filepath: str | Path) -> None:
35
35
 
36
36
  from soundfile import available_formats
37
37
 
38
- from sonusai import SonusAIError
39
-
40
38
  if not exists(input_filepath):
41
- raise SonusAIError(f'input_filepath {input_filepath} does not exist.')
39
+ raise OSError(f"input_filepath {input_filepath} does not exist.")
42
40
 
43
41
  ext = splitext(input_filepath)[1][1:].lower()
44
- read_formats = [item.lower() for item in available_formats().keys()]
42
+ read_formats = [item.lower() for item in available_formats()]
45
43
  if ext not in read_formats:
46
- raise SonusAIError(f'This installation cannot process .{ext} files')
44
+ raise OSError(f"This installation cannot process .{ext} files")
47
45
 
48
46
 
49
47
  @lru_cache