sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/metrics/one_hot.py
CHANGED
@@ -4,44 +4,46 @@ from sonusai.mixture.datatypes import Predict
|
|
4
4
|
from sonusai.mixture.datatypes import Truth
|
5
5
|
|
6
6
|
|
7
|
-
def one_hot(
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
7
|
+
def one_hot(
|
8
|
+
truth: Truth,
|
9
|
+
predict: Predict,
|
10
|
+
predict_thr: float | np.ndarray = 0,
|
11
|
+
truth_thr: float = 0.5,
|
12
|
+
timesteps: int = -1,
|
13
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
14
|
+
"""Calculates metrics from one-hot prediction and truth data (numpy float arrays) where
|
15
|
+
both are one-hot probabilities (or quantized decisions) for each class
|
16
|
+
with size [frames, num_classes] or [frames, timesteps, num_classes].
|
17
|
+
For metrics that require it, truth and pred decisions will be made using threshold >= predict_thr.
|
18
|
+
Some metrics like AP and AUC do not depend on predict_thr for predict, but still use truth >= predict_thr
|
19
|
+
|
20
|
+
predict_thr sets the decision threshold(s) applied to predict data for some metrics, thus allowing
|
21
|
+
the input to be continuous probabilities, for AUC-type metrics and root-mean-square error (rmse).
|
22
|
+
1. Default = 0 (multiclass or binary) which infers:
|
23
|
+
binary (num_classes = 1) use >= 0.5 for truth and pred (same as argmax() for binary)
|
24
|
+
multi-class/single-label if truth_mutex= = true, use argmax() used on both truth and pred
|
25
|
+
note multilabel metrics are disabled for predict_thr = 0, must set predict_thr > 0
|
26
|
+
|
27
|
+
2. predict_thr > 0 (multilabel or binary) scalar or a vector [num_classes, 1] then use
|
28
|
+
predict_thr as a binary decision threshold in each class:
|
29
|
+
binary (num_classes = 1) use >= predict_thr[0] for pred and predict_thr[num_classes+1] for truth
|
30
|
+
if it exists, else use >= 0.5 for truth
|
31
|
+
multilabel use >= predict_thr for pred if scalar, or predict_thr[class_idx] if vector
|
32
|
+
use >= predict_thr[num_classes+1] for truth if exists, else 0.5
|
33
|
+
note multi-class/single-label inputs are meaningless in this mode, use predict_thr = 0 argmax mode
|
34
|
+
|
35
|
+
num_classes is inferred from 1D, 2D, or 3D truth inputs by default (default timesteps = -1 which implies None).
|
36
|
+
Only set timesteps > 0 in case of ambiguous binary 2D case where input [frames, timesteps],
|
37
|
+
then it must set to the number of timesteps (which will be > 0).
|
38
|
+
It is safe to always set timesteps <= 0 for binary inputs, and if truth.shape[2] exists
|
39
|
+
|
40
|
+
returns metrics over all frames + timesteps:
|
41
|
+
mcm [num_classes, 2, 2] multiclass confusion matrix count ove
|
42
|
+
metrics [num_classes, 14] [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
|
43
|
+
cm [num_classes, num_classes] confusion matrix
|
44
|
+
cmn [num_classes, num_classes] normalized confusion matrix
|
45
|
+
rmse [num_classes, 1] RMS error over all frames + timesteps, before threshold decision
|
46
|
+
mavg [3, 8] averages macro, micro, weighted [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
|
45
47
|
"""
|
46
48
|
import warnings
|
47
49
|
|
@@ -51,14 +53,13 @@ def one_hot(truth: Truth,
|
|
51
53
|
from sklearn.metrics import precision_recall_fscore_support
|
52
54
|
from sklearn.metrics import roc_auc_score
|
53
55
|
|
54
|
-
from sonusai import SonusAIError
|
55
56
|
from sonusai.utils import get_num_classes_from_predict
|
56
57
|
from sonusai.utils import reshape_outputs
|
57
58
|
|
58
59
|
if truth.shape != predict.shape:
|
59
|
-
raise
|
60
|
+
raise ValueError("truth and predict are not the same shape")
|
60
61
|
|
61
|
-
predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
|
62
|
+
predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps) # type: ignore[assignment]
|
62
63
|
num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
|
63
64
|
|
64
65
|
# Regression metric root-mean-square-error always works
|
@@ -79,7 +80,8 @@ def one_hot(truth: Truth,
|
|
79
80
|
else:
|
80
81
|
if predict_thr.ndim > 1:
|
81
82
|
# multilabel with custom thr vector
|
82
|
-
|
83
|
+
if predict_thr.shape[0] != num_classes:
|
84
|
+
raise ValueError("predict_thr has wrong shape")
|
83
85
|
else:
|
84
86
|
if predict_thr == 0:
|
85
87
|
# binary or multilabel scalar default
|
@@ -89,18 +91,18 @@ def one_hot(truth: Truth,
|
|
89
91
|
predict_thr = np.atleast_1d(predict_thr)
|
90
92
|
|
91
93
|
if not isinstance(predict_thr, np.ndarray):
|
92
|
-
raise
|
94
|
+
raise TypeError(f"predict_thr is invalid type: {type(predict_thr)}")
|
93
95
|
|
94
96
|
# Convert continuous probabilities to binary via argmax() or threshold comparison
|
95
97
|
# and create labels of int encoded (0:num_classes-1), and then equivalent one-hot
|
96
98
|
if num_classes == 1: # If binary
|
97
|
-
labels = (
|
99
|
+
labels = list(range(0, 2)) # int encoded 0,1
|
98
100
|
plabel = np.int8(predict >= predict_thr) # [frames, 1], default 0.5 is equiv. to argmax()
|
99
101
|
tlabel = np.int8(truth >= truth_thr) # [frames, 1]
|
100
102
|
predb = np.array(plabel)
|
101
103
|
truthb = np.array(tlabel)
|
102
104
|
else:
|
103
|
-
labels = (
|
105
|
+
labels = list(range(0, num_classes)) # int encoded 0,...,num_classes-1
|
104
106
|
if predict_thr[0] == 0: # multiclass single-label (mutex), use argmax
|
105
107
|
plabel = np.argmax(predict, axis=-1) # [frames, 1] labels
|
106
108
|
tlabel = np.argmax(truth, axis=-1) # [frames, 1] labels
|
@@ -134,7 +136,7 @@ def one_hot(truth: Truth,
|
|
134
136
|
mcm = mcm[1:] # remove dim 0 if binary
|
135
137
|
|
136
138
|
# Create [num_classes, num_classes] normalized confusion matrix
|
137
|
-
cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize=
|
139
|
+
cmn = confusion_matrix(tlabel, plabel, labels=labels, normalize="true")
|
138
140
|
|
139
141
|
# Create [num_classes, num_classes] confusion matrix
|
140
142
|
cm = confusion_matrix(tlabel, plabel, labels=labels)
|
@@ -194,7 +196,22 @@ def one_hot(truth: Truth,
|
|
194
196
|
# ix = np.argmax(fscore) # index of largest f1 score
|
195
197
|
# threshold_optpr[nci] = thresholds[ix]
|
196
198
|
|
197
|
-
metrics[nci, :] = [
|
199
|
+
metrics[nci, :] = [
|
200
|
+
ACC,
|
201
|
+
TPR,
|
202
|
+
PPV,
|
203
|
+
TNR,
|
204
|
+
FPR,
|
205
|
+
HITFA,
|
206
|
+
F1,
|
207
|
+
MCC,
|
208
|
+
NT,
|
209
|
+
PT,
|
210
|
+
TP,
|
211
|
+
FP,
|
212
|
+
AP,
|
213
|
+
AUC,
|
214
|
+
]
|
198
215
|
|
199
216
|
# Calculate averages into single array, 3 types for now Macro, Micro, Weighted
|
200
217
|
mavg = np.zeros((3, 8), dtype=np.float32)
|
@@ -202,9 +219,17 @@ def one_hot(truth: Truth,
|
|
202
219
|
|
203
220
|
# macro average [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
|
204
221
|
with warnings.catch_warnings():
|
205
|
-
warnings.filterwarnings(action=
|
206
|
-
mavg[0, :] = [
|
207
|
-
|
222
|
+
warnings.filterwarnings(action="ignore", message="Mean of empty slice")
|
223
|
+
mavg[0, :] = [
|
224
|
+
np.mean(metrics[:, 2]),
|
225
|
+
np.mean(metrics[:, 1]),
|
226
|
+
np.mean(metrics[:, 6]),
|
227
|
+
np.mean(metrics[:, 4]),
|
228
|
+
np.mean(metrics[:, 0]),
|
229
|
+
np.nanmean(metrics[:, 12]),
|
230
|
+
np.nanmean(metrics[:, 13]),
|
231
|
+
s,
|
232
|
+
]
|
208
233
|
|
209
234
|
# micro average, micro-F1 = micro-precision = micro-recall = accuracy
|
210
235
|
if num_classes > 1:
|
@@ -218,25 +243,34 @@ def one_hot(truth: Truth,
|
|
218
243
|
tn_sum = sum(mcm[:, 0, 0])
|
219
244
|
accm = (tp_sum + tn_sum) / (tp_sum + tn_sum + fp_sum + fn_sum + eps)
|
220
245
|
with warnings.catch_warnings():
|
221
|
-
warnings.filterwarnings(action=
|
222
|
-
miap = average_precision_score(truthb, predict, average=
|
246
|
+
warnings.filterwarnings(action="ignore", message="invalid value encountered in true_divide")
|
247
|
+
miap = average_precision_score(truthb, predict, average="micro")
|
223
248
|
if np.sum(truthb): # no activity over all classes
|
224
|
-
miauc = roc_auc_score(truthb, predict, average=
|
249
|
+
miauc = roc_auc_score(truthb, predict, average="micro")
|
225
250
|
else:
|
226
251
|
miauc = np.NaN
|
227
252
|
|
228
253
|
# [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
|
229
|
-
mavg[1, :] = [
|
254
|
+
mavg[1, :] = [
|
255
|
+
pm,
|
256
|
+
rm,
|
257
|
+
f1m,
|
258
|
+
fpm,
|
259
|
+
accm,
|
260
|
+
miap,
|
261
|
+
miauc,
|
262
|
+
s,
|
263
|
+
] # specific format, last 3 are unique
|
230
264
|
|
231
265
|
# weighted average TBD
|
232
|
-
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average=
|
266
|
+
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
|
233
267
|
if np.sum(truthb):
|
234
268
|
taidx = np.sum(truthb, axis=0) > 0
|
235
|
-
wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average=
|
269
|
+
wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
|
236
270
|
if len(np.unique(truthb[:, taidx])) < 2:
|
237
271
|
wauc = np.NaN
|
238
272
|
else:
|
239
|
-
wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average=
|
273
|
+
wauc = roc_auc_score(truthb[:, taidx], predict[:, taidx], average="weighted")
|
240
274
|
else:
|
241
275
|
wap = np.NaN
|
242
276
|
wauc = np.NaN
|
sonusai/metrics/snr_summary.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
# ruff: noqa: F821
|
1
2
|
import numpy as np
|
2
3
|
import pandas as pd
|
3
4
|
|
@@ -8,32 +9,34 @@ from sonusai.mixture import Segsnr
|
|
8
9
|
from sonusai.mixture import Truth
|
9
10
|
|
10
11
|
|
11
|
-
def snr_summary(
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
12
|
+
def snr_summary(
|
13
|
+
mixdb: MixtureDatabase,
|
14
|
+
mixid: GeneralizedIDs,
|
15
|
+
truth_f: Truth,
|
16
|
+
predict: Predict,
|
17
|
+
segsnr: Segsnr | None = None,
|
18
|
+
predict_thr: float | np.ndarray = 0,
|
19
|
+
truth_thr: float = 0.5,
|
20
|
+
timesteps: int = 0,
|
21
|
+
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, dict]:
|
19
22
|
"""Calculate average-over-class metrics per SNR over specified mixture list.
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
23
|
+
Inputs:
|
24
|
+
mixdb Mixture database
|
25
|
+
mixid
|
26
|
+
truth_f Truth/labels [features, num_classes]
|
27
|
+
predict Prediction data / neural net model one-hot out [features, num_classes]
|
28
|
+
segsnr Segmental SNR from SonusAI genft [transform_frames, 1]
|
29
|
+
predict_thr Decision threshold(s) applied to predict data, allowing predict to be
|
30
|
+
continuous probabilities or decisions
|
31
|
+
truth_thr Decision threshold(s) applied to truth data, allowing truth to be
|
32
|
+
continuous probabilities or decisions
|
33
|
+
timesteps
|
34
|
+
|
35
|
+
Default predict_thr=0 will infer 0.5 for multi-label mode (truth_mutex = False), or
|
36
|
+
if single-label mode (truth_mutex == True) then ignore and use argmax mode, and
|
37
|
+
the confusion matrix is calculated for all classes.
|
38
|
+
|
39
|
+
Returns pandas dataframe (snrdf) of metrics per SNR.
|
37
40
|
"""
|
38
41
|
import warnings
|
39
42
|
|
@@ -53,14 +56,13 @@ def snr_summary(mixdb: MixtureDatabase,
|
|
53
56
|
else:
|
54
57
|
predict_thr = np.atleast_1d(predict_thr)
|
55
58
|
else:
|
56
|
-
if predict_thr.ndim == 1:
|
57
|
-
if
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
predict_thr = predict_thr[0]
|
59
|
+
if predict_thr.ndim == 1 and len(predict_thr) == 1:
|
60
|
+
if predict_thr[0] == 0:
|
61
|
+
# multi-label predict_thr array scalar 0 force to 0.5 default
|
62
|
+
predict_thr = np.atleast_1d(0.5)
|
63
|
+
else:
|
64
|
+
# multi-label predict_thr array set to scalar = array[0]
|
65
|
+
predict_thr = predict_thr[0]
|
64
66
|
|
65
67
|
macro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
|
66
68
|
micro_avg = np.zeros((len(snr_mixids), 7), dtype=np.float32)
|
@@ -72,13 +74,16 @@ def snr_summary(mixdb: MixtureDatabase,
|
|
72
74
|
# prep segsnr if provided, transform frames to feature frames via mean()
|
73
75
|
# expected to always be an integer
|
74
76
|
feature_frames = int(segsnr.shape[0] / truth_f.shape[0])
|
75
|
-
segsnr_f = np.mean(
|
77
|
+
segsnr_f = np.mean(
|
78
|
+
np.reshape(segsnr, (truth_f.shape[0], feature_frames)),
|
79
|
+
axis=1,
|
80
|
+
keepdims=True,
|
81
|
+
)
|
76
82
|
ssnr_stats = np.zeros((len(snr_mixids), 3), dtype=np.float32)
|
77
83
|
|
78
|
-
ii
|
79
|
-
for snr in snr_mixids:
|
84
|
+
for ii, snr in enumerate(snr_mixids):
|
80
85
|
# TODO: re-work for modern mixdb API
|
81
|
-
y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore
|
86
|
+
y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
|
82
87
|
_, metrics, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
|
83
88
|
|
84
89
|
# mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
|
@@ -87,20 +92,18 @@ def snr_summary(mixdb: MixtureDatabase,
|
|
87
92
|
wghtd_avg[ii, :] = mavg[2, 0:7]
|
88
93
|
if segsnr is not None:
|
89
94
|
# TODO: re-work for modern mixdb API
|
90
|
-
y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f) # type: ignore
|
95
|
+
y_truth, y_segsnr = get_mixids_data(mixdb, snr_mixids[snr], truth_f, segsnr_f) # type: ignore[name-defined]
|
91
96
|
with warnings.catch_warnings():
|
92
|
-
warnings.filterwarnings(action=
|
97
|
+
warnings.filterwarnings(action="ignore", message="divide by zero encountered in log10")
|
93
98
|
# segmental SNR mean = mixture_snr and target_snr
|
94
|
-
ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr))
|
99
|
+
ssnr_stats[ii, 0] = 10 * np.log10(np.mean(y_segsnr)) # type: ignore[index]
|
95
100
|
# segmental SNR 80% percentile
|
96
|
-
ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method=
|
101
|
+
ssnr_stats[ii, 1] = 10 * np.log10(np.percentile(y_segsnr, 80, method="midpoint")) # type: ignore[index]
|
97
102
|
# segmental SNR max
|
98
|
-
ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr))
|
99
|
-
|
100
|
-
ii += 1
|
103
|
+
ssnr_stats[ii, 2] = 10 * np.log10(max(y_segsnr)) # type: ignore[index]
|
101
104
|
|
102
105
|
# SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
|
103
|
-
col_n = [
|
106
|
+
col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
|
104
107
|
snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
|
105
108
|
snr_macrodf.sort_index(ascending=False, inplace=True)
|
106
109
|
|
@@ -112,7 +115,11 @@ def snr_summary(mixdb: MixtureDatabase,
|
|
112
115
|
|
113
116
|
# Add segmental SNR columns if provided
|
114
117
|
if segsnr is not None:
|
115
|
-
ssnrdf = pd.DataFrame(
|
118
|
+
ssnrdf = pd.DataFrame(
|
119
|
+
ssnr_stats,
|
120
|
+
index=list(snr_mixids.keys()),
|
121
|
+
columns=["SSNRavg", "SSNR80p", "SSNRmax"],
|
122
|
+
)
|
116
123
|
ssnrdf.sort_index(ascending=False, inplace=True)
|
117
124
|
snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
|
118
125
|
snr_microdf = pd.concat([snr_microdf, ssnrdf], axis=1)
|
sonusai/mixture/__init__.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
1
|
# SonusAI mixture utilities
|
2
|
+
# ruff: noqa: F401
|
3
|
+
|
2
4
|
from .audio import get_duration
|
3
5
|
from .audio import get_next_noise
|
4
6
|
from .audio import get_num_samples
|
@@ -19,15 +21,15 @@ from .augmentation import pad_audio_to_length
|
|
19
21
|
from .class_count import get_class_count_from_mixids
|
20
22
|
from .config import get_default_config
|
21
23
|
from .config import get_impulse_response_files
|
22
|
-
from .config import get_max_class
|
23
24
|
from .config import get_noise_files
|
24
25
|
from .config import get_spectral_masks
|
25
26
|
from .config import get_target_files
|
27
|
+
from .config import get_truth_parameters
|
26
28
|
from .config import load_config
|
27
29
|
from .config import raw_load_config
|
28
30
|
from .config import update_config_from_file
|
29
31
|
from .config import update_config_from_hierarchy
|
30
|
-
from .config import
|
32
|
+
from .config import validate_truth_configs
|
31
33
|
from .constants import BIT_DEPTH
|
32
34
|
from .constants import CHANNEL_COUNT
|
33
35
|
from .constants import DEFAULT_CONFIG
|
@@ -35,19 +37,22 @@ from .constants import DEFAULT_NOISE
|
|
35
37
|
from .constants import DEFAULT_SPEECH
|
36
38
|
from .constants import ENCODING
|
37
39
|
from .constants import FLOAT_BYTES
|
40
|
+
from .constants import MIXDB_VERSION
|
38
41
|
from .constants import RAND_PATTERN
|
39
42
|
from .constants import REQUIRED_CONFIGS
|
43
|
+
from .constants import REQUIRED_TRUTH_CONFIGS
|
40
44
|
from .constants import SAMPLE_BYTES
|
41
45
|
from .constants import SAMPLE_RATE
|
42
46
|
from .constants import VALID_AUGMENTATIONS
|
43
47
|
from .constants import VALID_CONFIGS
|
44
48
|
from .constants import VALID_NOISE_MIX_MODES
|
45
|
-
from .
|
49
|
+
from .data_io import read_cached_data
|
50
|
+
from .data_io import write_cached_data
|
46
51
|
from .datatypes import AudioF
|
47
|
-
from .datatypes import AudioStatsMetrics
|
48
|
-
from .datatypes import AudioT
|
49
52
|
from .datatypes import AudiosF
|
50
53
|
from .datatypes import AudiosT
|
54
|
+
from .datatypes import AudioStatsMetrics
|
55
|
+
from .datatypes import AudioT
|
51
56
|
from .datatypes import Augmentation
|
52
57
|
from .datatypes import AugmentationRule
|
53
58
|
from .datatypes import AugmentationRules
|
@@ -60,10 +65,11 @@ from .datatypes import EnergyT
|
|
60
65
|
from .datatypes import Feature
|
61
66
|
from .datatypes import FeatureGeneratorConfig
|
62
67
|
from .datatypes import FeatureGeneratorInfo
|
68
|
+
from .datatypes import GeneralizedIDs
|
63
69
|
from .datatypes import GenFTData
|
64
70
|
from .datatypes import GenMixData
|
65
|
-
from .datatypes import GeneralizedIDs
|
66
71
|
from .datatypes import ImpulseResponseData
|
72
|
+
from .datatypes import ImpulseResponseFile
|
67
73
|
from .datatypes import ImpulseResponseFiles
|
68
74
|
from .datatypes import ListAudiosT
|
69
75
|
from .datatypes import MetricDoc
|
@@ -84,9 +90,11 @@ from .datatypes import TargetFile
|
|
84
90
|
from .datatypes import TargetFiles
|
85
91
|
from .datatypes import TransformConfig
|
86
92
|
from .datatypes import Truth
|
87
|
-
from .datatypes import
|
88
|
-
from .datatypes import
|
89
|
-
from .datatypes import
|
93
|
+
from .datatypes import TruthConfig
|
94
|
+
from .datatypes import TruthConfigs
|
95
|
+
from .datatypes import TruthDict
|
96
|
+
from .datatypes import TruthParameter
|
97
|
+
from .datatypes import TruthParameters
|
90
98
|
from .datatypes import UniversalSNR
|
91
99
|
from .feature import get_audio_from_feature
|
92
100
|
from .feature import get_feature_from_audio
|
@@ -101,6 +109,7 @@ from .generation import populate_noise_file_table
|
|
101
109
|
from .generation import populate_spectral_mask_table
|
102
110
|
from .generation import populate_target_file_table
|
103
111
|
from .generation import populate_top_table
|
112
|
+
from .generation import populate_truth_parameters_table
|
104
113
|
from .generation import update_mixid_width
|
105
114
|
from .generation import update_mixture
|
106
115
|
from .helpers import augmented_noise_samples
|
@@ -112,11 +121,9 @@ from .helpers import get_audio_from_transform
|
|
112
121
|
from .helpers import get_ft
|
113
122
|
from .helpers import get_segsnr
|
114
123
|
from .helpers import get_transform_from_audio
|
115
|
-
from .helpers import
|
124
|
+
from .helpers import get_truth
|
116
125
|
from .helpers import inverse_transform
|
117
126
|
from .helpers import mixture_metadata
|
118
|
-
from .helpers import read_mixture_data
|
119
|
-
from .helpers import write_mixture_data
|
120
127
|
from .helpers import write_mixture_metadata
|
121
128
|
from .log_duration_and_sizes import log_duration_and_sizes
|
122
129
|
from .mixdb import MixtureDatabase
|
@@ -128,9 +135,8 @@ from .targets import get_augmented_target_ids_by_class
|
|
128
135
|
from .targets import get_augmented_target_ids_for_mixup
|
129
136
|
from .targets import get_augmented_targets
|
130
137
|
from .targets import get_target_augmentations_for_mixup
|
131
|
-
from .targets import get_truth_indices_for_target
|
132
138
|
from .tokenized_shell_vars import tokenized_expand
|
133
139
|
from .tokenized_shell_vars import tokenized_replace
|
134
140
|
from .truth import get_truth_indices_for_mixid
|
135
141
|
from .truth import truth_function
|
136
|
-
from .truth import
|
142
|
+
from .truth import truth_stride_reduction
|
sonusai/mixture/audio.py
CHANGED
@@ -15,7 +15,7 @@ def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
|
|
15
15
|
"""
|
16
16
|
import numpy as np
|
17
17
|
|
18
|
-
return np.take(audio, range(offset, offset + length), mode=
|
18
|
+
return np.take(audio, range(offset, offset + length), mode="wrap")
|
19
19
|
|
20
20
|
|
21
21
|
def get_duration(audio: AudioT) -> float:
|
@@ -35,15 +35,13 @@ def validate_input_file(input_filepath: str | Path) -> None:
|
|
35
35
|
|
36
36
|
from soundfile import available_formats
|
37
37
|
|
38
|
-
from sonusai import SonusAIError
|
39
|
-
|
40
38
|
if not exists(input_filepath):
|
41
|
-
raise
|
39
|
+
raise OSError(f"input_filepath {input_filepath} does not exist.")
|
42
40
|
|
43
41
|
ext = splitext(input_filepath)[1][1:].lower()
|
44
|
-
read_formats = [item.lower() for item in available_formats()
|
42
|
+
read_formats = [item.lower() for item in available_formats()]
|
45
43
|
if ext not in read_formats:
|
46
|
-
raise
|
44
|
+
raise OSError(f"This installation cannot process .{ext} files")
|
47
45
|
|
48
46
|
|
49
47
|
@lru_cache
|