sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/metrics/one_hot.py
@@ -0,0 +1,283 @@
+ import numpy as np
+
+ from ..datatypes import Predict
+ from ..datatypes import Truth
+
+
+ def one_hot(
+     truth: Truth,
+     predict: Predict,
+     predict_thr: float | np.ndarray = 0,
+     truth_thr: float = 0.5,
+     timesteps: int = -1,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """Calculate metrics from one-hot prediction and truth data (numpy float arrays) where
+     both are one-hot probabilities (or quantized decisions) for each class
+     with size [frames, num_classes] or [frames, timesteps, num_classes].
+     For metrics that require decisions, truth and predict are binarized using truth_thr and predict_thr, respectively.
+     Some metrics like AP and AUC do not depend on predict_thr for predict, but still use binarized truth.
+
+     predict_thr sets the decision threshold(s) applied to predict data for some metrics, thus allowing
+     the input to be continuous probabilities, for AUC-type metrics and root-mean-square error (RMSE).
+     1. Default = 0 (multiclass or binary), which infers:
+        binary (num_classes == 1): use >= 0.5 for truth and predict (same as argmax() for binary)
+        multiclass/single-label (truth_mutex == true): use argmax() on both truth and predict
+        note: multilabel metrics are disabled for predict_thr = 0; set predict_thr > 0 instead
+
+     2. predict_thr > 0 (multilabel or binary), scalar or vector [num_classes, 1]: use
+        predict_thr as a binary decision threshold in each class:
+        binary (num_classes == 1): use >= predict_thr[0] for predict, and >= predict_thr[num_classes+1]
+        for truth if it exists, else >= 0.5 for truth
+        multilabel: use >= predict_thr for predict if scalar, or predict_thr[class_idx] if vector;
+        use >= predict_thr[num_classes+1] for truth if it exists, else 0.5
+        note: multiclass/single-label inputs are meaningless in this mode; use predict_thr = 0 argmax mode
+
+     num_classes is inferred from 1D, 2D, or 3D truth inputs by default (default timesteps = -1, which implies None).
+     Only set timesteps > 0 in the ambiguous binary 2D case where the input is [frames, timesteps];
+     it must then be set to the number of timesteps (which will be > 0).
+     It is safe to always set timesteps <= 0 for binary inputs and when truth.shape[2] exists.
+
+     Returns metrics over all frames + timesteps:
+       mcm      [num_classes, 2, 2]         multiclass confusion matrix counts
+       metrics  [num_classes, 14]           [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
+       cm       [num_classes, num_classes]  confusion matrix
+       cmn      [num_classes, num_classes]  normalized confusion matrix
+       rmse     [num_classes, 1]            RMS error over all frames + timesteps, before threshold decision
+       mavg     [3, 8]                      macro, micro, weighted averages [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+     """
+     import warnings
+
+     from sklearn.metrics import average_precision_score
+     from sklearn.metrics import confusion_matrix
+     from sklearn.metrics import multilabel_confusion_matrix
+     from sklearn.metrics import precision_recall_fscore_support
+     from sklearn.metrics import roc_auc_score
+
+     from ..utils.reshape import get_num_classes_from_predict
+     from ..utils.reshape import reshape_outputs
+
+     if truth.shape != predict.shape:
+         raise ValueError("truth and predict are not the same shape")
+
+     predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)  # type: ignore[assignment]
+     num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
+
+     # Regression metric root-mean-square-error always works
+     rmse = np.sqrt(np.mean(np.square(truth - predict), axis=0))
+
+     # Calculate default predict decision thresholds based on mode
+     if not isinstance(predict_thr, np.ndarray):
+         if predict_thr == 0:
+             # if scalar and 0, set defaults
+             if num_classes == 1:
+                 # binary case: default >= 0.5, which is equivalent to argmax()
+                 predict_thr = np.atleast_1d(0.5)
+             else:
+                 # multiclass, single-label (argmax mode)
+                 predict_thr = np.atleast_1d(0)
+         else:
+             predict_thr = np.atleast_1d(predict_thr)
+     else:
+         if predict_thr.ndim > 1:
+             # multilabel with custom threshold vector
+             if predict_thr.shape[0] != num_classes:
+                 raise ValueError("predict_thr has wrong shape")
+         else:
+             if predict_thr == 0:
+                 # binary or multilabel scalar default
+                 predict_thr = np.atleast_1d(0.5)
+             else:
+                 # user-specified binary or multilabel scalar
+                 predict_thr = np.atleast_1d(predict_thr)
+
+     if not isinstance(predict_thr, np.ndarray):
+         raise TypeError(f"predict_thr is invalid type: {type(predict_thr)}")
+
+     # Convert continuous probabilities to binary via argmax() or threshold comparison,
+     # creating int-encoded labels (0:num_classes-1) and the equivalent one-hot arrays
+     if num_classes == 1:  # if binary
+         labels = list(range(0, 2))  # int encoded 0,1
+         plabel = np.int8(predict >= predict_thr)  # [frames, 1], default 0.5 is equivalent to argmax()
+         tlabel = np.int8(truth >= truth_thr)  # [frames, 1]
+         predb = np.array(plabel)
+         truthb = np.array(tlabel)
+     else:
+         labels = list(range(0, num_classes))  # int encoded 0,...,num_classes-1
+         if predict_thr[0] == 0:  # multiclass single-label (mutex), use argmax
+             plabel = np.argmax(predict, axis=-1)  # [frames, 1] labels
+             tlabel = np.argmax(truth, axis=-1)  # [frames, 1] labels
+             # one-hot binary
+             predb = np.zeros(predict.shape, dtype=np.int8)  # [frames, num_classes]
+             truthb = np.zeros(truth.shape, dtype=np.int8)  # [frames, num_classes]
+             predb[np.arange(predb.shape[0]), plabel] = 1  # single-label [frames, num_classes]
+             if np.sum(truth):  # special case: all-zero truth leaves tlabel all zeros
+                 truthb[np.arange(truthb.shape[0]), tlabel] = 1  # single-label [frames, num_classes]
+         else:  # multilabel probability threshold comparison (multiple classes)
+             # multilabel one-hot decision
+             predb = np.array(predict >= predict_thr.transpose()).astype(np.int8)  # [frames, num_classes]
+             truthb = np.array(truth >= truth_thr).astype(np.int8)  # [frames, num_classes]
+             # Return argmax() for optional single-label confusion matrix metrics
+             plabel = np.argmax(predict, axis=-1)  # [frames, 1] labels
+             tlabel = np.argmax(truth, axis=-1)  # [frames, 1] labels
+
+     # debug checks to understand AP, AUC:
+     # from sklearn.metrics import roc_curve
+     # fpr, tpr, thr = roc_curve(truthb[:, 0], predict[:, 0], drop_intermediate=False)
+     # from sklearn.metrics import precision_recall_curve
+     # precision, recall, thr = precision_recall_curve(truthb[:, 0], predict[:, 0])
+     # from sklearn.metrics import RocCurveDisplay
+     # RocCurveDisplay.from_predictions(truthb[:, 0], predict[:, 0])  # plot ROC for class 0
+
+     # Create [num_classes, 2, 2] multilabel confusion matrix (mcm)
+     # Note: must include labels or the sklearn function will omit non-existing classes
+     mcm = multilabel_confusion_matrix(truthb, predb, labels=labels)
+
+     if num_classes == 1:
+         mcm = mcm[1:]  # if binary, keep only the positive-class entry
+
+     # Create [num_classes, num_classes] normalized confusion matrix
+     cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize="true")
+
+     # Create [num_classes, num_classes] confusion matrix
+     cm = confusion_matrix(tlabel, plabel, labels=labels)
+
+     # Combine all per-class metrics into a single array
+     # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
+     metrics = np.zeros((num_classes, 14))
+     # threshold_optpr = np.zeros((num_classes, 1))
+     eps = np.finfo(float).eps
+     for nci in range(num_classes):
+         # True negative
+         TN = mcm[nci, 0, 0]
+         # False positive
+         FP = mcm[nci, 0, 1]
+         # False negative
+         FN = mcm[nci, 1, 0]
+         # True positive
+         TP = mcm[nci, 1, 1]
+         # Accuracy
+         ACC = (TP + TN) / (TP + TN + FP + FN + eps)
+         # True positive rate, sensitivity, recall, hit rate (note eps in denominator)
+         # When ``true positive + false negative == 0``, recall is undefined; set to 0
+         TPR = TP / (TP + FN + eps)
+         # Precision, positive predictive value
+         # When ``true positive + false positive == 0``, precision is undefined; set to 0
+         PPV = TP / (TP + FP + eps)
+         # Specificity, i.e., selectivity or true negative rate
+         TNR = TN / (TN + FP + eps)
+         # False positive rate = 1 - specificity, ROC x-axis
+         FPR = FP / (TN + FP + eps)
+         # HitFA, used by some separation research; close match to MCC
+         HITFA = TPR - FPR
+         # F1, harmonic mean of precision and recall = 2*PPV*TPR / (PPV + TPR)
+         F1 = 2 * TP / (2 * TP + FP + FN + eps)
+         # Matthews correlation coefficient
+         MCC = (TP * TN - FP * FN) / (np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) + eps)
+         # Num. negatives total (truth), also = TN+FP, the denominator of FPR
+         NT = sum(mcm[nci, 0, :])
+         # Num. positives total (truth), also = FN+TP, the denominator of TPR and precision
+         PT = sum(mcm[nci, 1, :])
+         # Average precision (AP), also called area under the PR curve (AUCPR), and
+         # AUC of the ROC curve, using binarized truth and continuous prediction probabilities.
+         # sklearn returns NaN if a class has no active truth, but with an un-suppressible div-by-zero warning
+         if np.sum(truthb[:, nci]) == 0:  # no active truth in this class: both sklearn metrics fail, set to NaN
+             AUC = np.NaN
+             AP = np.NaN
+             # threshold_optpr[nci] = np.NaN
+         else:
+             AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
+             if len(np.unique(truthb[:, nci])) < 2:  # AUC needs both classes present in truth,
+                 AUC = np.NaN  # i.e. all ones makes sklearn AUC fail; set to NaN
+             else:
+                 AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
+         # # Optimal threshold from PR curve, optimizes f-score
+         # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
+         # fscore = (2 * precision * recall) / (precision + recall)
+         # ix = np.argmax(fscore)  # index of largest f1 score
+         # threshold_optpr[nci] = thresholds[ix]
+
+         metrics[nci, :] = [
+             ACC,
+             TPR,
+             PPV,
+             TNR,
+             FPR,
+             HITFA,
+             F1,
+             MCC,
+             NT,
+             PT,
+             TP,
+             FP,
+             AP,
+             AUC,
+         ]
+
+     # Calculate averages into a single array; 3 types for now: macro, micro, weighted
+     mavg = np.zeros((3, 8), dtype=np.float32)
+     s = np.sum(metrics[:, 9].astype(int))  # support = sum of PT (= FN+TP) over classes
+
+     # macro average [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+     with warnings.catch_warnings():
+         warnings.filterwarnings(action="ignore", message="Mean of empty slice")
+         mavg[0, :] = [
+             np.mean(metrics[:, 2]),
+             np.mean(metrics[:, 1]),
+             np.mean(metrics[:, 6]),
+             np.mean(metrics[:, 4]),
+             np.mean(metrics[:, 0]),
+             np.nanmean(metrics[:, 12]),
+             np.nanmean(metrics[:, 13]),
+             s,
+         ]
+
+     # micro average; for single-label, micro-F1 = micro-precision = micro-recall = accuracy
+     if num_classes > 1:
+         tp_sum = np.sum(metrics[:, 10])  # TP over all classes
+         rm = tp_sum / (np.sum(metrics[:, 9]) + eps)  # micro mean TPR (recall) = TP / (PT=FN+TP)
+         fp_sum = np.sum(metrics[:, 11])  # FP (false positives) over all classes
+         fpm = fp_sum / (np.sum(metrics[:, 8]) + eps)  # micro mean FPR = FP / (NT=TN+FP)
+         pm = tp_sum / (tp_sum + fp_sum + eps)  # micro mean PPV (precision) = TP / (TP+FP); equals rm in single-label mode
+         fn_sum = sum(mcm[:, 1, 0])
+         f1m = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum + eps)
+         tn_sum = sum(mcm[:, 0, 0])
+         accm = (tp_sum + tn_sum) / (tp_sum + tn_sum + fp_sum + fn_sum + eps)
+         with warnings.catch_warnings():
+             warnings.filterwarnings(action="ignore", message="invalid value encountered in true_divide")
+             miap = average_precision_score(truthb, predict, average="micro")
+             if np.sum(truthb):  # any truth activity over all classes
+                 miauc = roc_auc_score(truthb, predict, average="micro")
+             else:
+                 miauc = np.NaN
+
+         # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
+         mavg[1, :] = [
+             pm,
+             rm,
+             f1m,
+             fpm,
+             accm,
+             miap,
+             miauc,
+             s,
+         ]  # specific format; last 3 are unique
+
+         # weighted average
+         wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)  # pyright: ignore [reportArgumentType]
+         if np.sum(truthb):
+             taidx = np.sum(truthb, axis=0) > 0
+             wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
+             if len(np.unique(truthb[:, taidx])) < 2:
+                 wauc = np.NaN
+             else:
+                 wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average="weighted")
+         else:
+             wap = np.NaN
+             wauc = np.NaN
+
+         mavg[2, :] = [wp, wr, wf1, 0, 0, wap, wauc, s]  # FPR and ACC are not computed for the weighted average
+     else:  # binary case, all averages are the same
+         mavg[1, :] = mavg[0, :]
+         mavg[2, :] = mavg[0, :]
+
+     return mcm, metrics, cm, cmn, rmse, mavg
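For reference, a minimal usage sketch of the two predict_thr modes described in the docstring above, assuming this wheel is installed. The random data, shapes, seed, and 0.5 threshold are illustrative only, not part of the package.

import numpy as np

from sonusai.metrics.one_hot import one_hot

rng = np.random.default_rng(0)
frames, num_classes = 1000, 4

# Multilabel-style truth (classes can be active independently) and continuous predictions
truth = (rng.random((frames, num_classes)) > 0.7).astype(np.float32)
predict = np.clip(truth * 0.6 + 0.4 * rng.random((frames, num_classes)), 0.0, 1.0)

# predict_thr > 0 selects multilabel mode: each class is thresholded at 0.5
mcm, metrics, cm, cmn, rmse, mavg = one_hot(truth, predict, predict_thr=0.5)
print(mcm.shape)      # (4, 2, 2) per-class confusion counts
print(metrics.shape)  # (4, 14)   [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
print(mavg.shape)     # (3, 8)    macro / micro / weighted rows

# predict_thr = 0 selects argmax (multiclass single-label) mode instead
mcm, metrics, cm, cmn, rmse, mavg = one_hot(truth, predict, predict_thr=0)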
sonusai/metrics/snr_summary.py
@@ -0,0 +1,128 @@
+ # ruff: noqa: F821
+ import numpy as np
+ import pandas as pd
+
+ from ..datatypes import GeneralizedIDs
+ from ..datatypes import Predict
+ from ..datatypes import Segsnr
+ from ..datatypes import Truth
+ from ..mixture.mixdb import MixtureDatabase
+
+
+ def snr_summary(
+     mixdb: MixtureDatabase,
+     mixid: GeneralizedIDs,
+     truth_f: Truth,
+     predict: Predict,
+     segsnr: Segsnr | None = None,
+     predict_thr: float | np.ndarray = 0,
+     truth_thr: float = 0.5,
+     timesteps: int = 0,
+ ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
+     """Calculate average-over-class metrics per SNR over the specified mixture list.
+     Inputs:
+       mixdb        Mixture database
+       mixid
+       truth_f      Truth/labels [features, num_classes]
+       predict      Prediction data / neural net model one-hot output [features, num_classes]
+       segsnr       Segmental SNR from SonusAI genft [transform_frames, 1]
+       predict_thr  Decision threshold(s) applied to predict data, allowing predict to be
+                    continuous probabilities or decisions
+       truth_thr    Decision threshold(s) applied to truth data, allowing truth to be
+                    continuous probabilities or decisions
+       timesteps
+
+     Default predict_thr=0 infers 0.5 for multilabel mode (truth_mutex == False); in
+     single-label mode (truth_mutex == True) it is ignored, argmax mode is used, and
+     the confusion matrix is calculated for all classes.
+
+     Returns pandas DataFrames of metrics per SNR (macro, micro, and weighted averages) plus the SNR-to-mixid mapping.
+     """
+     import warnings
+
+     from ..metrics.one_hot import one_hot
+     from ..queries.queries import get_mixids_from_snr
+
+     num_classes = truth_f.shape[1]
+
+     snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
+
+     # Check whether predict_thr is an array or a scalar and resolve the final predict_thr value
+     if num_classes > 1:
+         if not isinstance(predict_thr, np.ndarray):
+             if predict_thr == 0:
+                 # multilabel predict_thr scalar 0: force to the 0.5 default
+                 predict_thr = np.atleast_1d(0.5)
+             else:
+                 predict_thr = np.atleast_1d(predict_thr)
+         else:
+             if predict_thr.ndim == 1 and len(predict_thr) == 1:
+                 if predict_thr[0] == 0:
+                     # multilabel predict_thr array scalar 0: force to the 0.5 default
+                     predict_thr = np.atleast_1d(0.5)
+                 else:
+                     # multilabel predict_thr array: set to scalar = array[0]
+                     predict_thr = predict_thr[0]
+
+     macro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+     micro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+     wghtd_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+     ssnr_stats = None
+     segsnr_f = None
+
+     if segsnr is not None:
+         # prep segsnr if provided: collapse transform frames to feature frames via mean()
+         # expected to always be an integer
+         feature_frames = int(segsnr.shape[0] / truth_f.shape[0])
+         segsnr_f = np.mean(
+             np.reshape(segsnr, (truth_f.shape[0], feature_frames)),
+             axis=1,
+             keepdims=True,
+         )
+         ssnr_stats = np.zeros((len(snr_mixids), 3), dtype=np.float32)
+
+     for ii, snr in enumerate(snr_mixids):
+         # TODO: re-work for modern mixdb API
+         y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict)  # type: ignore[name-defined]
+         _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
+
+         # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+         macro_avg[ii, :] = mavg[0, 0:7]
+         micro_avg[ii, :] = mavg[1, 0:7]
+         wghtd_avg[ii, :] = mavg[2, 0:7]
+         if segsnr is not None:
+             # TODO: re-work for modern mixdb API
+             y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f)  # type: ignore[name-defined]
+             with warnings.catch_warnings():
+                 warnings.filterwarnings(action="ignore", message="divide by zero encountered in log10")
+                 # segmental SNR mean (tracks mixture_snr and target_snr)
+                 ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr))  # type: ignore[index]
+                 # segmental SNR 80th percentile
+                 ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method="midpoint"))  # type: ignore[index]
+                 # segmental SNR max
+                 ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr))  # type: ignore[index]
+
+     # Column format: PPV, TPR, F1, FPR, ACC, AP, AUC
+     col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
+     snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+     snr_macrodf.sort_index(ascending=False, inplace=True)
+
+     snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+     snr_microdf.sort_index(ascending=False, inplace=True)
+
+     snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+     snr_wghtdf.sort_index(ascending=False, inplace=True)
+
+     # Add segmental SNR columns if provided
+     if segsnr is not None:
+         ssnrdf = pd.DataFrame(
+             ssnr_stats,
+             index=list(snr_mixids.keys()),  # pyright: ignore [reportArgumentType]
+             columns=["SSNRavg", "SSNR80p", "SSNRmax"],  # pyright: ignore [reportArgumentType]
+         )
+         ssnrdf.sort_index(ascending=False, inplace=True)
+         snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
+         snr_microdf = pd.concat([snr_microdf, ssnrdf], axis=1)
+         snr_wghtdf = pd.concat([snr_wghtdf, ssnrdf], axis=1)
+
+     return snr_macrodf, snr_microdf, snr_wghtdf, snr_mixids
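Note that snr_summary still references a legacy get_mixids_data helper (see the TODO markers and the ruff noqa: F821 at the top of the file), so it cannot run end-to-end against the modern mixdb API. The segmental-SNR aggregation it performs is self-contained, though; below is a minimal sketch of that step alone, assuming segsnr holds linear (not dB) power ratios and an integer number of transform frames per feature frame. The shapes and gamma-distributed data are illustrative only.

import numpy as np

rng = np.random.default_rng(0)
feature_frames, frames_per_feature = 100, 4

# Linear-scale segmental SNR, one value per transform frame
segsnr = rng.gamma(2.0, 2.0, size=feature_frames * frames_per_feature)

# Collapse transform frames to feature frames via mean(), as in snr_summary
segsnr_f = np.mean(segsnr.reshape(feature_frames, frames_per_feature), axis=1, keepdims=True)

# Summary statistics reported per SNR row, converted to dB
ssnr_avg = 10 * np.log10(np.mean(segsnr_f))                                # SSNRavg
ssnr_80p = 10 * np.log10(np.percentile(segsnr_f, 80, method="midpoint"))  # SSNR80p
ssnr_max = 10 * np.log10(np.max(segsnr_f))                                # SSNRmax
print(f"SSNRavg={ssnr_avg:.1f} dB  SSNR80p={ssnr_80p:.1f} dB  SSNRmax={ssnr_max:.1f} dB")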