sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/metrics/one_hot.py
@@ -0,0 +1,283 @@
+import numpy as np
+
+from ..datatypes import Predict
+from ..datatypes import Truth
+
+
+def one_hot(
+    truth: Truth,
+    predict: Predict,
+    predict_thr: float | np.ndarray = 0,
+    truth_thr: float = 0.5,
+    timesteps: int = -1,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    """Calculates metrics from one-hot prediction and truth data (numpy float arrays) where
+    both are one-hot probabilities (or quantized decisions) for each class
+    with size [frames, num_classes] or [frames, timesteps, num_classes].
+    For metrics that require decisions, predict is thresholded with >= predict_thr and
+    truth with >= truth_thr. Some metrics like AP and AUC do not depend on predict_thr
+    for predict, but still use thresholded truth.
+
+    predict_thr sets the decision threshold(s) applied to predict data for some metrics, thus
+    allowing the input to be continuous probabilities, as needed for AUC-type metrics and
+    root-mean-square error (rmse).
+    1. Default = 0 (multiclass or binary), which infers:
+       binary (num_classes = 1): use >= 0.5 for truth and predict (same as argmax() for binary)
+       multiclass/single-label (truth_mutex == true): use argmax() on both truth and predict
+       note: multilabel metrics are disabled for predict_thr = 0; set predict_thr > 0 to enable them
+
+    2. predict_thr > 0 (multilabel or binary), scalar or a vector [num_classes, 1]: use
+       predict_thr as a binary decision threshold in each class:
+       binary (num_classes = 1): use >= predict_thr[0] for predict, and >= predict_thr[num_classes+1]
+       for truth if it exists, else >= 0.5 for truth
+       multilabel: use >= predict_thr for predict if scalar, or predict_thr[class_idx] if vector;
+       use >= predict_thr[num_classes+1] for truth if it exists, else 0.5
+       note: multiclass/single-label inputs are meaningless in this mode; use predict_thr = 0 argmax mode
+
+    num_classes is inferred from 1D, 2D, or 3D truth inputs by default (default timesteps = -1,
+    which implies None). Only set timesteps > 0 in the ambiguous binary 2D case where the input
+    is [frames, timesteps]; then it must be set to the number of timesteps (which will be > 0).
+    It is safe to always set timesteps <= 0 for binary inputs and whenever truth.shape[2] exists.
+
+    returns metrics over all frames + timesteps:
+      mcm     [num_classes, 2, 2]          multiclass confusion matrix counts over all frames
+      metrics [num_classes, 14]            [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
+      cm      [num_classes, num_classes]   confusion matrix
+      cmn     [num_classes, num_classes]   normalized confusion matrix
+      rmse    [num_classes, 1]             RMS error over all frames + timesteps, before threshold decision
+      mavg    [3, 8]                       macro, micro, and weighted averages of [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+    """
+    import warnings
+
+    from sklearn.metrics import average_precision_score
+    from sklearn.metrics import confusion_matrix
+    from sklearn.metrics import multilabel_confusion_matrix
+    from sklearn.metrics import precision_recall_fscore_support
+    from sklearn.metrics import roc_auc_score
+
+    from ..utils.reshape import get_num_classes_from_predict
+    from ..utils.reshape import reshape_outputs
+
+    if truth.shape != predict.shape:
+        raise ValueError("truth and predict are not the same shape")
+
+    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)  # type: ignore[assignment]
+    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
+
+    # Regression metric root-mean-square-error always works
+    rmse = np.sqrt(np.mean(np.square(truth - predict), axis=0))
+
+    # Calculate default predict decision thresholds based on mode
+    if not isinstance(predict_thr, np.ndarray):
+        if predict_thr == 0:
+            # if scalar and 0, set defaults
+            if num_classes == 1:
+                # binary case: default >= 0.5, which is equivalent to argmax()
+                predict_thr = np.atleast_1d(0.5)
+            else:
+                # multiclass, single-label (argmax mode)
+                predict_thr = np.atleast_1d(0)
+        else:
+            predict_thr = np.atleast_1d(predict_thr)
+    else:
+        if predict_thr.ndim > 1:
+            # multilabel with custom threshold vector
+            if predict_thr.shape[0] != num_classes:
+                raise ValueError("predict_thr has wrong shape")
+        else:
+            if predict_thr == 0:
+                # binary or multilabel scalar default
+                predict_thr = np.atleast_1d(0.5)
+            else:
+                # user-specified binary or multilabel scalar
+                predict_thr = np.atleast_1d(predict_thr)
+
+    if not isinstance(predict_thr, np.ndarray):
+        raise TypeError(f"predict_thr is invalid type: {type(predict_thr)}")
+
+    # Convert continuous probabilities to binary via argmax() or threshold comparison,
+    # create int-encoded labels (0:num_classes-1), and then the equivalent one-hot
+    if num_classes == 1:  # binary
+        labels = list(range(0, 2))  # int encoded 0,1
+        plabel = np.int8(predict >= predict_thr)  # [frames, 1], default 0.5 is equivalent to argmax()
+        tlabel = np.int8(truth >= truth_thr)  # [frames, 1]
+        predb = np.array(plabel)
+        truthb = np.array(tlabel)
+    else:
+        labels = list(range(0, num_classes))  # int encoded 0,...,num_classes-1
+        if predict_thr[0] == 0:  # multiclass single-label (mutex), use argmax
+            plabel = np.argmax(predict, axis=-1)  # [frames, 1] labels
+            tlabel = np.argmax(truth, axis=-1)  # [frames, 1] labels
+            # one-hot binary
+            predb = np.zeros(predict.shape, dtype=np.int8)  # [frames, num_classes]
+            truthb = np.zeros(truth.shape, dtype=np.int8)  # [frames, num_classes]
+            predb[np.arange(predb.shape[0]), plabel] = 1  # single-label [frames, num_classes]
+            if np.sum(truth):  # special case: all-zero truth leaves tlabel all zeros
+                truthb[np.arange(truthb.shape[0]), tlabel] = 1  # single-label [frames, num_classes]
+        else:  # multilabel probability threshold comparison (multiple classes)
+            # multilabel one-hot decision
+            predb = np.array(predict >= predict_thr.transpose()).astype(np.int8)  # [frames, num_classes]
+            truthb = np.array(truth >= truth_thr).astype(np.int8)  # [frames, num_classes]
+            # Return argmax() for optional single-label confusion matrix metrics
+            plabel = np.argmax(predict, axis=-1)  # [frames, 1] labels
+            tlabel = np.argmax(truth, axis=-1)  # [frames, 1] labels
+
+    # debug checks to understand ap, auc:
+    # from sklearn.metrics import roc_curve
+    # fpr, tpr, thr = roc_curve(truthb[:, 0], predict[:, 0], drop_intermediate=False)
+    # from sklearn.metrics import precision_recall_curve
+    # precision, recall, thr = precision_recall_curve(truthb[:, 0], predict[:, 0])
+    # from sklearn.metrics import RocCurveDisplay
+    # RocCurveDisplay.from_predictions(truthb[:, 0], predict[:, 0])  # Plot ROC class0
+
+    # Create [num_classes, 2, 2] multilabel confusion matrix (mcm)
+    # Note: must include labels or the sklearn function will omit non-existing classes
+    mcm = multilabel_confusion_matrix(truthb, predb, labels=labels)
+
+    if num_classes == 1:
+        mcm = mcm[1:]  # remove dim 0 if binary
+
+    # Create [num_classes, num_classes] normalized confusion matrix
+    cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize="true")
+
+    # Create [num_classes, num_classes] confusion matrix
+    cm = confusion_matrix(tlabel, plabel, labels=labels)
+
+    # Combine all per-class metrics into a single array
+    # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
+    metrics = np.zeros((num_classes, 14))
+    # threshold_optpr = np.zeros((num_classes, 1))
+    eps = np.finfo(float).eps
+    for nci in range(num_classes):
+        # True negative
+        TN = mcm[nci, 0, 0]
+        # False positive
+        FP = mcm[nci, 0, 1]
+        # False negative
+        FN = mcm[nci, 1, 0]
+        # True positive
+        TP = mcm[nci, 1, 1]
+        # Accuracy
+        ACC = (TP + TN) / (TP + TN + FP + FN + eps)
+        # True positive rate, sensitivity, recall, hit rate (note eps in denominator)
+        # When ``true positive + false negative == 0``, recall is undefined, set to 0
+        TPR = TP / (TP + FN + eps)
+        # Precision, positive predictive value
+        # When ``true positive + false positive == 0``, precision is undefined, set to 0
+        PPV = TP / (TP + FP + eps)
+        # Specificity, i.e. selectivity or true negative rate
+        TNR = TN / (TN + FP + eps)
+        # False positive rate = 1 - specificity, ROC x-axis
+        FPR = FP / (TN + FP + eps)
+        # HITFA, used by some separation research; close match to MCC
+        HITFA = TPR - FPR
+        # F1, harmonic mean of precision and recall = 2*PPV*TPR / (PPV + TPR)
+        F1 = 2 * TP / (2 * TP + FP + FN + eps)
+        # Matthews correlation coefficient
+        MCC = (TP * TN - FP * FN) / (np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) + eps)
+        # Num. negatives total (truth), also = TN+FP, the denominator of FPR
+        NT = sum(mcm[nci, 0, :])
+        # Num. positives total (truth), also = FN+TP, the denominator of TPR and precision
+        PT = sum(mcm[nci, 1, :])
+        # Average precision (AP), also called area under the PR curve (AUCPR), and
+        # AUC of the ROC curve, using binarized truth and continuous prediction probabilities.
+        # sklearn returns nan if a class has no active truth, with an un-suppressible div-by-zero warning.
+        if np.sum(truthb[:, nci]) == 0:  # no active truth in this class: both sklearn functions fail, set to NaN
+            AUC = np.NaN
+            AP = np.NaN
+            # threshold_optpr[nci] = np.NaN
+        else:
+            AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
+            if len(np.unique(truthb[:, nci])) < 2:  # fewer than 2 unique truth values: AUC must be NaN
+                AUC = np.NaN  # i.e. all ones; sklearn roc_auc_score would fail
+            else:
+                AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)  # pyright: ignore [reportArgumentType]
+            # # Optimal threshold from PR curve, optimizes f-score
+            # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
+            # fscore = (2 * precision * recall) / (precision + recall)
+            # ix = np.argmax(fscore)  # index of largest f1 score
+            # threshold_optpr[nci] = thresholds[ix]
+
+        metrics[nci, :] = [
+            ACC,
+            TPR,
+            PPV,
+            TNR,
+            FPR,
+            HITFA,
+            F1,
+            MCC,
+            NT,
+            PT,
+            TP,
+            FP,
+            AP,
+            AUC,
+        ]
+
+    # Calculate averages into a single array; 3 types for now: macro, micro, weighted
+    mavg = np.zeros((3, 8), dtype=np.float32)
+    s = np.sum(metrics[:, 9].astype(int))  # support = sum of PT (true positives total = FN+TP) over classes
+
+    # macro average [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+    with warnings.catch_warnings():
+        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
+        mavg[0, :] = [
+            np.mean(metrics[:, 2]),
+            np.mean(metrics[:, 1]),
+            np.mean(metrics[:, 6]),
+            np.mean(metrics[:, 4]),
+            np.mean(metrics[:, 0]),
+            np.nanmean(metrics[:, 12]),
+            np.nanmean(metrics[:, 13]),
+            s,
+        ]
+
+    # micro average; micro-F1 = micro-precision = micro-recall = accuracy for single-label data
+    if num_classes > 1:
+        tp_sum = np.sum(metrics[:, 10])  # TP over all classes
+        rm = tp_sum / (np.sum(metrics[:, 9]) + eps)  # micro mean TPR = TP / (PT=FN+TP)
+        fp_sum = np.sum(metrics[:, 11])  # FP over all classes
+        fpm = fp_sum / (np.sum(metrics[:, 8]) + eps)  # micro mean FPR = FP / (NT=TN+FP)
+        pm = tp_sum / (tp_sum + fp_sum + eps)  # micro mean PPV = TP / (TP+FP) (equals rm for single-label data)
+        fn_sum = sum(mcm[:, 1, 0])
+        f1m = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum + eps)
+        tn_sum = sum(mcm[:, 0, 0])
+        accm = (tp_sum + tn_sum) / (tp_sum + tn_sum + fp_sum + fn_sum + eps)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(action="ignore", message="invalid value encountered in true_divide")
+            miap = average_precision_score(truthb, predict, average="micro")
+            if np.sum(truthb):  # micro AUC is undefined if there is no truth activity in any class
+                miauc = roc_auc_score(truthb, predict, average="micro")
+            else:
+                miauc = np.NaN
+
+        # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
+        mavg[1, :] = [
+            pm,
+            rm,
+            f1m,
+            fpm,
+            accm,
+            miap,
+            miauc,
+            s,
+        ]  # specific format, last 3 are unique
+
+        # weighted average (FPR and ACC are not computed here; left as 0)
+        wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)  # pyright: ignore [reportArgumentType]
+        if np.sum(truthb):
+            taidx = np.sum(truthb, axis=0) > 0
+            wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
+            if len(np.unique(truthb[:, taidx])) < 2:
+                wauc = np.NaN
+            else:
+                wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average="weighted")
+        else:
+            wap = np.NaN
+            wauc = np.NaN
+
+        mavg[2, :] = [wp, wr, wf1, 0, 0, wap, wauc, s]
+    else:  # binary case: all three averages are the same
+        mavg[1, :] = mavg[0, :]
+        mavg[2, :] = mavg[0, :]
+
+    return mcm, metrics, cm, cmn, rmse, mavg
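For context on the threshold modes described in the docstring above, here is a minimal usage sketch of `one_hot` on synthetic single-label data. The synthetic inputs and the top-level import path `sonusai.metrics.one_hot` are illustrative assumptions (they presume this wheel is installed as `sonusai`); only the signature and return layout come from the diff itself.

```python
import numpy as np

from sonusai.metrics.one_hot import one_hot  # assumed install path for this wheel

rng = np.random.default_rng(0)
num_frames, num_classes = 100, 3

# Synthetic single-label truth: one active class per frame, one-hot encoded.
active = rng.integers(0, num_classes, num_frames)
truth = np.zeros((num_frames, num_classes), dtype=np.float32)
truth[np.arange(num_frames), active] = 1.0

# Noisy prediction probabilities that mostly agree with truth.
raw = truth + 0.5 * rng.random((num_frames, num_classes)).astype(np.float32)
predict = raw / raw.sum(axis=1, keepdims=True)

# predict_thr=0 (the default) selects argmax mode for multiclass/single-label data.
mcm, metrics, cm, cmn, rmse, mavg = one_hot(truth, predict, predict_thr=0)
print(metrics.shape)  # (3, 14): per-class [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
print(mavg[0])        # macro average [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]

# predict_thr > 0 switches to multilabel mode: one threshold decision per class.
_, metrics_ml, *_ = one_hot(truth, predict, predict_thr=0.4)
```

Note that per-class AP and AUC come back as NaN for classes with no truth activity, which is why the macro average uses `np.nanmean` for those two columns.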
sonusai/metrics/snr_summary.py
@@ -0,0 +1,128 @@
+# ruff: noqa: F821
+import numpy as np
+import pandas as pd
+
+from ..datatypes import GeneralizedIDs
+from ..datatypes import Predict
+from ..datatypes import Segsnr
+from ..datatypes import Truth
+from ..mixture.mixdb import MixtureDatabase
+
+
+def snr_summary(
+    mixdb: MixtureDatabase,
+    mixid: GeneralizedIDs,
+    truth_f: Truth,
+    predict: Predict,
+    segsnr: Segsnr | None = None,
+    predict_thr: float | np.ndarray = 0,
+    truth_thr: float = 0.5,
+    timesteps: int = 0,
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
+    """Calculate average-over-class metrics per SNR over the specified mixture list.
+
+    Inputs:
+      mixdb        Mixture database
+      mixid        Mixture ID(s) to include
+      truth_f      Truth/labels [features, num_classes]
+      predict      Prediction data / neural net model one-hot output [features, num_classes]
+      segsnr       Segmental SNR from SonusAI genft [transform_frames, 1]
+      predict_thr  Decision threshold(s) applied to predict data, allowing predict to be
+                   continuous probabilities or decisions
+      truth_thr    Decision threshold(s) applied to truth data, allowing truth to be
+                   continuous probabilities or decisions
+      timesteps    Timesteps dimension, passed through to one_hot (0 if unused)
+
+    The default predict_thr=0 infers 0.5 for multi-label mode (truth_mutex == false); in
+    single-label mode (truth_mutex == true) it is ignored, argmax mode is used, and
+    the confusion matrix is calculated over all classes.
+
+    Returns pandas DataFrames of macro-, micro-, and weighted-average metrics per SNR,
+    plus the per-SNR mixture IDs.
+    """
+    import warnings
+
+    from ..metrics.one_hot import one_hot
+    from ..queries.queries import get_mixids_from_snr
+
+    num_classes = truth_f.shape[1]
+
+    snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
+
+    # Check predict_thr array or scalar and reduce to the final predict_thr value
+    if num_classes > 1:
+        if not isinstance(predict_thr, np.ndarray):
+            if predict_thr == 0:
+                # multi-label predict_thr scalar 0: force to the 0.5 default
+                predict_thr = np.atleast_1d(0.5)
+            else:
+                predict_thr = np.atleast_1d(predict_thr)
+        else:
+            if predict_thr.ndim == 1 and len(predict_thr) == 1:
+                if predict_thr[0] == 0:
+                    # multi-label predict_thr array scalar 0: force to the 0.5 default
+                    predict_thr = np.atleast_1d(0.5)
+                else:
+                    # multi-label predict_thr array: reduce to scalar = array[0]
+                    predict_thr = predict_thr[0]
+
+    macro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+    micro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+    wghtd_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
+    ssnr_stats = None
+    segsnr_f = None
+
+    if segsnr is not None:
+        # prep segsnr if provided: reduce transform frames to feature frames via mean();
+        # expected to always be an integer
+        feature_frames = int(segsnr.shape[0] / truth_f.shape[0])
+        segsnr_f = np.mean(
+            np.reshape(segsnr, (truth_f.shape[0], feature_frames)),
+            axis=1,
+            keepdims=True,
+        )
+        ssnr_stats = np.zeros((len(snr_mixids), 3), dtype=np.float32)
+
+    for ii, snr in enumerate(snr_mixids):
+        # TODO: re-work for modern mixdb API
+        y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict)  # type: ignore[name-defined]
+        _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
+
+        # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
+        macro_avg[ii, :] = mavg[0, 0:7]
+        micro_avg[ii, :] = mavg[1, 0:7]
+        wghtd_avg[ii, :] = mavg[2, 0:7]
+        if segsnr is not None:
+            # TODO: re-work for modern mixdb API
+            y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f)  # type: ignore[name-defined]
+            with warnings.catch_warnings():
+                warnings.filterwarnings(action="ignore", message="divide by zero encountered in log10")
+                # segmental SNR mean = mixture_snr and target_snr
+                ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr))  # type: ignore[index]
+                # segmental SNR 80th percentile
+                ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method="midpoint"))  # type: ignore[index]
+                # segmental SNR max
+                ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr))  # type: ignore[index]
+
+    # per-SNR column format: PPV, TPR, F1, FPR, ACC, AP, AUC
+    col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
+    snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+    snr_macrodf.sort_index(ascending=False, inplace=True)
+
+    snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+    snr_microdf.sort_index(ascending=False, inplace=True)
+
+    snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)  # pyright: ignore [reportArgumentType]
+    snr_wghtdf.sort_index(ascending=False, inplace=True)
+
+    # Add segmental SNR columns if provided
+    if segsnr is not None:
+        ssnrdf = pd.DataFrame(
+            ssnr_stats,
+            index=list(snr_mixids.keys()),  # pyright: ignore [reportArgumentType]
+            columns=["SSNRavg", "SSNR80p", "SSNRmax"],  # pyright: ignore [reportArgumentType]
+        )
+        ssnrdf.sort_index(ascending=False, inplace=True)
+        snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
+        snr_microdf = pd.concat([snr_microdf, ssnrdf], axis=1)
+        snr_wghtdf = pd.concat([snr_wghtdf, ssnrdf], axis=1)
+
+    return snr_macrodf, snr_microdf, snr_wghtdf, snr_mixids
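`snr_summary` calls `get_mixids_data`, which is not defined in this module (hence the `# ruff: noqa: F821` directive and the TODO notes about the modern mixdb API). Below is a minimal sketch of what such a helper might do, under the assumption that the mixture database can map a mixture ID to its feature-frame indices; the `mixture_frame_indices` method name is hypothetical, not part of the sonusai API.

```python
import numpy as np


def get_mixids_data(mixdb, mixids, truth_f, predict):
    """Sketch: gather the feature-frame rows of truth_f and predict for the given mixture IDs.

    HYPOTHETICAL: mixture_frame_indices() is a stand-in for however MixtureDatabase
    maps a mixture ID to its feature-frame index range; the real API may differ.
    """
    rows: list[int] = []
    for m_id in mixids:
        rows.extend(mixdb.mixture_frame_indices(m_id))  # hypothetical API call
    idx = np.asarray(rows, dtype=int)
    return truth_f[idx], predict[idx]
```

With a real implementation of this helper in scope, `snr_summary(mixdb, mixid, truth_f, predict)` returns the macro, micro, and weighted DataFrames indexed by SNR, plus the per-SNR mixture ID dict, as described in the docstring.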