sonusai 0.19.6__py3-none-any.whl → 0.19.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +54 -62
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +202 -235
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/RECORD +58 -56
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
|
|
54
54
|
|
55
55
|
def calc_class_weights_from_mixdb(
|
56
56
|
mixdb: MixtureDatabase,
|
57
|
-
mixids: GeneralizedIDs
|
57
|
+
mixids: GeneralizedIDs = "*",
|
58
58
|
other_weight: float = 1,
|
59
59
|
other_index: int = -1,
|
60
60
|
) -> tuple[np.ndarray, np.ndarray]:
|
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
|
|
77
77
|
from sonusai.mixture import get_class_count_from_mixids
|
78
78
|
|
79
79
|
count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
|
80
|
-
if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
|
81
|
-
count[other_index] = count[other_index] / np.float32(other_weight)
|
82
80
|
total_features = sum(count)
|
83
81
|
|
84
82
|
weights = np.empty(mixdb.num_classes, dtype=np.float32)
|
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
|
|
51
51
|
AUC[nci] = np.NaN
|
52
52
|
AP[nci] = np.NaN
|
53
53
|
else:
|
54
|
-
AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
|
55
|
-
AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
|
54
|
+
AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
55
|
+
AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
56
56
|
|
57
57
|
# Optimal threshold from PR curve, optimizes f-score
|
58
58
|
precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
|
@@ -26,7 +26,7 @@ def calc_phase_distance(
|
|
26
26
|
# weighted mean over all (scalar)
|
27
27
|
reference_mag = np.abs(reference)
|
28
28
|
ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
|
29
|
-
err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
|
29
|
+
err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
|
30
30
|
|
31
31
|
# weighted mean over frames (value per bin)
|
32
32
|
err_b = np.zeros(reference.shape[1])
|
sonusai/metrics/calc_speech.py
CHANGED
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
|
|
32
32
|
llr_mean = np.mean(ll_rs[:llr_len])
|
33
33
|
|
34
34
|
# Segmental SNR
|
35
|
-
|
35
|
+
_, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
|
36
36
|
seg_snr = np.mean(segsnr_dist)
|
37
37
|
|
38
38
|
# PESQ
|
39
39
|
_pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
|
40
40
|
|
41
41
|
# Now compute the composite measures
|
42
|
-
csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
|
43
|
-
cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
|
44
|
-
covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
|
42
|
+
csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
|
43
|
+
cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
|
44
|
+
covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
|
45
45
|
|
46
46
|
return SpeechMetrics(_pesq, csig, cbak, covl)
|
47
47
|
|
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
|
|
284
284
|
hypothesis_frame = np.multiply(hypothesis_frame, window)
|
285
285
|
|
286
286
|
# (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
|
287
|
-
r_reference,
|
288
|
-
|
287
|
+
r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
|
288
|
+
_, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
|
289
289
|
|
290
290
|
# (3) Compute the log likelihood ratio measure
|
291
291
|
numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
|
sonusai/metrics/class_summary.py
CHANGED
@@ -38,7 +38,7 @@ def class_summary(
|
|
38
38
|
# TODO: re-work for modern mixdb API
|
39
39
|
y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
|
40
40
|
|
41
|
-
if
|
41
|
+
if num_classes > 1:
|
42
42
|
if not isinstance(predict_thr, np.ndarray):
|
43
43
|
if predict_thr == 0:
|
44
44
|
predict_thr = np.atleast_1d(0.5)
|
@@ -53,25 +53,16 @@ def class_summary(
|
|
53
53
|
# [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
|
54
54
|
table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
|
55
55
|
col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
|
56
|
-
if mixdb.
|
57
|
-
|
58
|
-
row_n = mixdb.class_labels
|
59
|
-
if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
|
60
|
-
row_n.append("Other")
|
61
|
-
else:
|
62
|
-
row_n = [f"Class {i}" for i in range(1, num_classes)]
|
63
|
-
row_n.append("Other")
|
56
|
+
if len(mixdb.class_labels) == num_classes:
|
57
|
+
row_n = mixdb.class_labels
|
64
58
|
else:
|
65
|
-
|
66
|
-
row_n = mixdb.class_labels
|
67
|
-
else:
|
68
|
-
row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
|
59
|
+
row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
|
69
60
|
|
70
|
-
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
|
61
|
+
df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
|
71
62
|
|
72
63
|
# [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
|
73
64
|
avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
|
74
|
-
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
|
65
|
+
dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
|
75
66
|
|
76
67
|
# dfblank = pd.DataFrame([''])
|
77
68
|
# pd.concat([df, dfblank, dfblank, dfavg])
|
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
|
|
37
37
|
ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
|
38
38
|
|
39
39
|
# Check predict_thr array or scalar and return final scalar predict_thr value
|
40
|
-
if
|
40
|
+
if num_classes > 1:
|
41
41
|
if not isinstance(predict_thr, np.ndarray):
|
42
42
|
if predict_thr == 0:
|
43
43
|
# multi-label predict_thr scalar 0 force to 0.5 default
|
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
|
|
61
61
|
else:
|
62
62
|
class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
|
63
63
|
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
# mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
|
75
|
-
|
76
|
-
cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
|
77
|
-
cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
|
78
|
-
|
79
|
-
else:
|
80
|
-
_, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
|
81
|
-
cname = class_names[class_idx]
|
82
|
-
row_n = ["TrueN", "TrueP"]
|
83
|
-
col_n = ["N-" + cname, "P-" + cname]
|
84
|
-
cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
|
85
|
-
cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
|
86
|
-
# add thresholds in 3rd row
|
87
|
-
pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
|
88
|
-
cmdf = pd.concat([cmdf, pdnote])
|
89
|
-
cmndf = pd.concat([cmndf, pdnote])
|
64
|
+
_, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
|
65
|
+
cname = class_names[class_idx]
|
66
|
+
row_n = ["TrueN", "TrueP"]
|
67
|
+
col_n = ["N-" + cname, "P-" + cname]
|
68
|
+
cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32) # pyright: ignore [reportArgumentType]
|
69
|
+
cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32) # pyright: ignore [reportArgumentType]
|
70
|
+
# add thresholds in 3rd row
|
71
|
+
pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n) # pyright: ignore [reportArgumentType, reportCallIssue]
|
72
|
+
cmdf = pd.concat([cmdf, pdnote])
|
73
|
+
cmndf = pd.concat([cmndf, pdnote])
|
90
74
|
|
91
75
|
return cmdf, cmndf
|
sonusai/metrics/one_hot.py
CHANGED
@@ -185,11 +185,11 @@ def one_hot(
|
|
185
185
|
AP = np.NaN
|
186
186
|
# threshold_optpr[nci] = np.NaN
|
187
187
|
else:
|
188
|
-
AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
|
188
|
+
AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
189
189
|
if len(np.unique(truthb[:, nci])) < 2: # if active classes not > 1 AUC must be NaN
|
190
190
|
AUC = np.NaN # i.e. all ones sklearn auc will fail
|
191
191
|
else:
|
192
|
-
AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
|
192
|
+
AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
|
193
193
|
# # Optimal threshold from PR curve, optimizes f-score
|
194
194
|
# precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
|
195
195
|
# fscore = (2 * precision * recall) / (precision + recall)
|
@@ -263,7 +263,7 @@ def one_hot(
|
|
263
263
|
] # specific format, last 3 are unique
|
264
264
|
|
265
265
|
# weighted average TBD
|
266
|
-
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
|
266
|
+
wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0) # pyright: ignore [reportArgumentType]
|
267
267
|
if np.sum(truthb):
|
268
268
|
taidx = np.sum(truthb, axis=0) > 0
|
269
269
|
wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
|
sonusai/metrics/snr_summary.py
CHANGED
@@ -48,7 +48,7 @@ def snr_summary(
|
|
48
48
|
snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
|
49
49
|
|
50
50
|
# Check predict_thr array or scalar and return final scalar predict_thr value
|
51
|
-
if
|
51
|
+
if num_classes > 1:
|
52
52
|
if not isinstance(predict_thr, np.ndarray):
|
53
53
|
if predict_thr == 0:
|
54
54
|
# multi-label predict_thr scalar 0 force to 0.5 default
|
@@ -84,7 +84,7 @@ def snr_summary(
|
|
84
84
|
for ii, snr in enumerate(snr_mixids):
|
85
85
|
# TODO: re-work for modern mixdb API
|
86
86
|
y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
|
87
|
-
_,
|
87
|
+
_, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
|
88
88
|
|
89
89
|
# mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
|
90
90
|
macro_avg[ii, :] = mavg[0, 0:7]
|
@@ -104,21 +104,21 @@ def snr_summary(
|
|
104
104
|
|
105
105
|
# SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
|
106
106
|
col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
|
107
|
-
snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
|
107
|
+
snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
108
108
|
snr_macrodf.sort_index(ascending=False, inplace=True)
|
109
109
|
|
110
|
-
snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
|
110
|
+
snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
111
111
|
snr_microdf.sort_index(ascending=False, inplace=True)
|
112
112
|
|
113
|
-
snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
|
113
|
+
snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
|
114
114
|
snr_wghtdf.sort_index(ascending=False, inplace=True)
|
115
115
|
|
116
116
|
# Add segmental SNR columns if provided
|
117
117
|
if segsnr is not None:
|
118
118
|
ssnrdf = pd.DataFrame(
|
119
119
|
ssnr_stats,
|
120
|
-
index=list(snr_mixids.keys()),
|
121
|
-
columns=["SSNRavg", "SSNR80p", "SSNRmax"],
|
120
|
+
index=list(snr_mixids.keys()), # pyright: ignore [reportArgumentType]
|
121
|
+
columns=["SSNRavg", "SSNR80p", "SSNRmax"], # pyright: ignore [reportArgumentType]
|
122
122
|
)
|
123
123
|
ssnrdf.sort_index(ascending=False, inplace=True)
|
124
124
|
snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
|
sonusai/mixture/__init__.py
CHANGED
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
|
|
46
46
|
from .constants import VALID_AUGMENTATIONS
|
47
47
|
from .constants import VALID_CONFIGS
|
48
48
|
from .constants import VALID_NOISE_MIX_MODES
|
49
|
+
from .data_io import clear_cached_data
|
49
50
|
from .data_io import read_cached_data
|
50
51
|
from .data_io import write_cached_data
|
51
52
|
from .datatypes import AudioF
|
52
|
-
from .datatypes import AudiosF
|
53
|
-
from .datatypes import AudiosT
|
54
53
|
from .datatypes import AudioStatsMetrics
|
55
54
|
from .datatypes import AudioT
|
56
55
|
from .datatypes import Augmentation
|
57
56
|
from .datatypes import AugmentationRule
|
58
|
-
from .datatypes import AugmentationRules
|
59
|
-
from .datatypes import Augmentations
|
60
57
|
from .datatypes import AugmentedTarget
|
61
|
-
from .datatypes import AugmentedTargets
|
62
58
|
from .datatypes import ClassCount
|
63
59
|
from .datatypes import EnergyF
|
64
60
|
from .datatypes import EnergyT
|
@@ -70,35 +66,27 @@ from .datatypes import GenFTData
|
|
70
66
|
from .datatypes import GenMixData
|
71
67
|
from .datatypes import ImpulseResponseData
|
72
68
|
from .datatypes import ImpulseResponseFile
|
73
|
-
from .datatypes import ImpulseResponseFiles
|
74
|
-
from .datatypes import ListAudiosT
|
75
69
|
from .datatypes import MetricDoc
|
76
70
|
from .datatypes import MetricDocs
|
77
71
|
from .datatypes import Mixture
|
78
72
|
from .datatypes import MixtureDatabaseConfig
|
79
|
-
from .datatypes import Mixtures
|
80
73
|
from .datatypes import NoiseFile
|
81
|
-
from .datatypes import NoiseFiles
|
82
74
|
from .datatypes import Predict
|
83
75
|
from .datatypes import Segsnr
|
84
76
|
from .datatypes import SnrFMetrics
|
85
77
|
from .datatypes import SpectralMask
|
86
|
-
from .datatypes import SpectralMasks
|
87
78
|
from .datatypes import SpeechMetadata
|
88
79
|
from .datatypes import SpeechMetrics
|
89
80
|
from .datatypes import TargetFile
|
90
|
-
from .datatypes import TargetFiles
|
91
81
|
from .datatypes import TransformConfig
|
92
82
|
from .datatypes import Truth
|
93
83
|
from .datatypes import TruthConfig
|
94
84
|
from .datatypes import TruthConfigs
|
95
85
|
from .datatypes import TruthDict
|
96
86
|
from .datatypes import TruthParameter
|
97
|
-
from .datatypes import TruthParameters
|
98
87
|
from .datatypes import UniversalSNR
|
99
88
|
from .feature import get_audio_from_feature
|
100
89
|
from .feature import get_feature_from_audio
|
101
|
-
from .generation import generate_mixtures
|
102
90
|
from .generation import get_all_snrs_from_config
|
103
91
|
from .generation import initialize_db
|
104
92
|
from .generation import populate_class_label_table
|
@@ -111,17 +99,14 @@ from .generation import populate_target_file_table
|
|
111
99
|
from .generation import populate_top_table
|
112
100
|
from .generation import populate_truth_parameters_table
|
113
101
|
from .generation import update_mixid_width
|
114
|
-
from .generation import
|
102
|
+
from .generation import update_mixture_table
|
115
103
|
from .helpers import augmented_noise_samples
|
116
104
|
from .helpers import augmented_target_samples
|
117
105
|
from .helpers import check_audio_files_exist
|
118
106
|
from .helpers import forward_transform
|
119
107
|
from .helpers import frames_from_samples
|
120
108
|
from .helpers import get_audio_from_transform
|
121
|
-
from .helpers import get_ft
|
122
|
-
from .helpers import get_segsnr
|
123
109
|
from .helpers import get_transform_from_audio
|
124
|
-
from .helpers import get_truth
|
125
110
|
from .helpers import inverse_transform
|
126
111
|
from .helpers import mixture_metadata
|
127
112
|
from .helpers import write_mixture_metadata
|
sonusai/mixture/augmentation.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AudioT
|
2
2
|
from sonusai.mixture.datatypes import Augmentation
|
3
3
|
from sonusai.mixture.datatypes import AugmentationRule
|
4
|
-
from sonusai.mixture.datatypes import AugmentationRules
|
5
4
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
6
5
|
from sonusai.mixture.datatypes import OptionalNumberStr
|
7
6
|
|
8
7
|
|
9
|
-
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) ->
|
8
|
+
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
|
10
9
|
"""Generate augmentation rules from list of input rules
|
11
10
|
|
12
11
|
:param rules: Dictionary of augmentation config rule[s]
|
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
|
|
25
24
|
rule = _parse_ir(rule, num_ir)
|
26
25
|
processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
|
27
26
|
|
28
|
-
return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
|
27
|
+
return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
|
29
28
|
|
30
29
|
|
31
30
|
def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
|
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
|
|
163
162
|
return length
|
164
163
|
|
165
164
|
|
166
|
-
def get_mixups(augmentations:
|
165
|
+
def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
|
167
166
|
"""Get a list of mixup values used
|
168
167
|
|
169
168
|
:param augmentations: List of augmentations
|
@@ -172,7 +171,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
|
|
172
171
|
return sorted({augmentation.mixup for augmentation in augmentations})
|
173
172
|
|
174
173
|
|
175
|
-
def get_augmentation_indices_for_mixup(augmentations:
|
174
|
+
def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
|
176
175
|
"""Get a list of augmentation indices for a given mixup value
|
177
176
|
|
178
177
|
:param augmentations: List of augmentations
|
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
|
|
327
326
|
if _rule_has_rand(processed_rule):
|
328
327
|
processed_rule = _generate_random_rule(processed_rule, num_ir)
|
329
328
|
|
330
|
-
return dataclass_from_dict(Augmentation, processed_rule)
|
329
|
+
return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
|
sonusai/mixture/class_count.py
CHANGED
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
|
|
3
3
|
from sonusai.mixture.mixdb import MixtureDatabase
|
4
4
|
|
5
5
|
|
6
|
-
def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs
|
6
|
+
def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
|
7
7
|
"""Sums the class counts for given mixids"""
|
8
8
|
total_class_count = [0] * mixdb.num_classes
|
9
9
|
m_ids = mixdb.mixids_to_list(mixids)
|
sonusai/mixture/config.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
from sonusai.mixture.datatypes import ImpulseResponseFile
|
2
|
-
from sonusai.mixture.datatypes import
|
3
|
-
from sonusai.mixture.datatypes import
|
4
|
-
from sonusai.mixture.datatypes import
|
5
|
-
from sonusai.mixture.datatypes import
|
6
|
-
from sonusai.mixture.datatypes import TruthParameters
|
2
|
+
from sonusai.mixture.datatypes import NoiseFile
|
3
|
+
from sonusai.mixture.datatypes import SpectralMask
|
4
|
+
from sonusai.mixture.datatypes import TargetFile
|
5
|
+
from sonusai.mixture.datatypes import TruthParameter
|
7
6
|
|
8
7
|
|
9
8
|
def raw_load_config(name: str) -> dict:
|
@@ -210,7 +209,7 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
|
|
210
209
|
return new_config
|
211
210
|
|
212
211
|
|
213
|
-
def get_target_files(config: dict, show_progress: bool = False) ->
|
212
|
+
def get_target_files(config: dict, show_progress: bool = False) -> list[TargetFile]:
|
214
213
|
"""Get the list of target files from a config
|
215
214
|
|
216
215
|
:param config: Config dictionary
|
@@ -223,7 +222,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
|
|
223
222
|
from sonusai.utils import par_track
|
224
223
|
from sonusai.utils import track
|
225
224
|
|
226
|
-
from .datatypes import
|
225
|
+
from .datatypes import TargetFile
|
227
226
|
|
228
227
|
class_indices = config["class_indices"]
|
229
228
|
if not isinstance(class_indices, list):
|
@@ -255,7 +254,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
|
|
255
254
|
if any(class_index > num_classes for class_index in target_file["class_indices"]):
|
256
255
|
raise ValueError(f"class index elements must not be greater than {num_classes}")
|
257
256
|
|
258
|
-
return dataclass_from_dict(
|
257
|
+
return dataclass_from_dict(list[TargetFile], target_files)
|
259
258
|
|
260
259
|
|
261
260
|
def append_target_files(
|
@@ -294,6 +293,7 @@ def append_target_files(
|
|
294
293
|
if tokens is None:
|
295
294
|
tokens = {}
|
296
295
|
|
296
|
+
truth_configs_merged = deepcopy(truth_configs)
|
297
297
|
if isinstance(entry, dict):
|
298
298
|
if "name" in entry:
|
299
299
|
in_name = entry["name"]
|
@@ -312,15 +312,11 @@ def append_target_files(
|
|
312
312
|
raise AttributeError(
|
313
313
|
f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
|
314
314
|
)
|
315
|
-
|
316
|
-
for key in truth_configs_override:
|
317
|
-
truth_configs_merged[key] = deepcopy(truth_configs[key])
|
318
|
-
if truth_configs_override[key] is not None:
|
315
|
+
if key in truth_configs_override:
|
319
316
|
truth_configs_merged[key] |= truth_configs_override[key]
|
320
317
|
level_type = entry.get("level_type", level_type)
|
321
318
|
else:
|
322
319
|
in_name = entry
|
323
|
-
truth_configs_merged = deepcopy(truth_configs)
|
324
320
|
|
325
321
|
in_name, new_tokens = tokenized_expand(in_name)
|
326
322
|
tokens.update(new_tokens)
|
@@ -416,7 +412,7 @@ def append_target_files(
|
|
416
412
|
return target_files
|
417
413
|
|
418
414
|
|
419
|
-
def get_noise_files(config: dict, show_progress: bool = False) ->
|
415
|
+
def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
|
420
416
|
"""Get the list of noise files from a config
|
421
417
|
|
422
418
|
:param config: Config dictionary
|
@@ -429,7 +425,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
|
|
429
425
|
from sonusai.utils import par_track
|
430
426
|
from sonusai.utils import track
|
431
427
|
|
432
|
-
from .datatypes import
|
428
|
+
from .datatypes import NoiseFile
|
433
429
|
|
434
430
|
noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
|
435
431
|
|
@@ -437,7 +433,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
|
|
437
433
|
noise_files = par_track(_get_num_samples, noise_files, progress=progress)
|
438
434
|
progress.close()
|
439
435
|
|
440
|
-
return dataclass_from_dict(
|
436
|
+
return dataclass_from_dict(list[NoiseFile], noise_files)
|
441
437
|
|
442
438
|
|
443
439
|
def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
|
@@ -522,26 +518,25 @@ def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[di
|
|
522
518
|
return noise_files
|
523
519
|
|
524
520
|
|
525
|
-
def get_impulse_response_files(config: dict) ->
|
521
|
+
def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
|
526
522
|
"""Get the list of impulse response files from a config
|
527
523
|
|
528
524
|
:param config: Config dictionary
|
529
525
|
:return: List of impulse response files
|
530
526
|
"""
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
|
527
|
+
from itertools import chain
|
528
|
+
|
529
|
+
return list(
|
530
|
+
chain.from_iterable(
|
531
|
+
[
|
532
|
+
append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
|
533
|
+
for entry in config["impulse_responses"]
|
534
|
+
]
|
535
|
+
)
|
536
|
+
)
|
537
|
+
|
538
|
+
|
539
|
+
def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
|
545
540
|
"""Process impulse response files list and append as needed
|
546
541
|
|
547
542
|
:param entry: Impulse response file entry to append to the list
|
@@ -569,7 +564,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
569
564
|
if not names:
|
570
565
|
raise OSError(f"Could not find {in_name}. Make sure path exists")
|
571
566
|
|
572
|
-
impulse_response_files: list[
|
567
|
+
impulse_response_files: list[ImpulseResponseFile] = []
|
573
568
|
for name in names:
|
574
569
|
ext = splitext(name)[1].lower()
|
575
570
|
dir_name = dirname(name)
|
@@ -607,14 +602,14 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
607
602
|
raise OSError(f"Error processing {name}: {e}") from e
|
608
603
|
else:
|
609
604
|
validate_input_file(name)
|
610
|
-
impulse_response_files.append(tokenized_replace(name, tokens))
|
605
|
+
impulse_response_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags))
|
611
606
|
except Exception as e:
|
612
607
|
raise OSError(f"Error processing {name}: {e}") from e
|
613
608
|
|
614
609
|
return impulse_response_files
|
615
610
|
|
616
611
|
|
617
|
-
def get_spectral_masks(config: dict) ->
|
612
|
+
def get_spectral_masks(config: dict) -> list[SpectralMask]:
|
618
613
|
"""Get the list of spectral masks from a config
|
619
614
|
|
620
615
|
:param config: Config dictionary
|
@@ -623,12 +618,12 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
|
|
623
618
|
from sonusai.utils import dataclass_from_dict
|
624
619
|
|
625
620
|
try:
|
626
|
-
return dataclass_from_dict(
|
621
|
+
return dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
|
627
622
|
except Exception as e:
|
628
623
|
raise ValueError(f"Error in spectral_masks: {e}") from e
|
629
624
|
|
630
625
|
|
631
|
-
def get_truth_parameters(config: dict) ->
|
626
|
+
def get_truth_parameters(config: dict) -> list[TruthParameter]:
|
632
627
|
"""Get the list of truth parameters from a config
|
633
628
|
|
634
629
|
:param config: Config dictionary
|
@@ -637,26 +632,21 @@ def get_truth_parameters(config: dict) -> TruthParameters:
|
|
637
632
|
from copy import deepcopy
|
638
633
|
|
639
634
|
from sonusai.mixture import truth_functions
|
640
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
641
635
|
|
642
636
|
from .constants import REQUIRED_TRUTH_CONFIGS
|
643
637
|
from .datatypes import TruthParameter
|
644
638
|
|
645
|
-
truth_parameters:
|
639
|
+
truth_parameters: list[TruthParameter] = []
|
646
640
|
for name, truth_config in config["truth_configs"].items():
|
647
641
|
optional_config = deepcopy(truth_config)
|
648
642
|
for key in REQUIRED_TRUTH_CONFIGS:
|
649
643
|
del optional_config[key]
|
650
644
|
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
target_gain=1,
|
656
|
-
config=optional_config,
|
645
|
+
parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
|
646
|
+
config["feature"],
|
647
|
+
config["num_classes"],
|
648
|
+
optional_config,
|
657
649
|
)
|
658
|
-
|
659
|
-
parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
|
660
650
|
truth_parameters.append(TruthParameter(name, parameters))
|
661
651
|
|
662
652
|
return truth_parameters
|
sonusai/mixture/data_io.py
CHANGED
@@ -128,6 +128,22 @@ def write_pickle_data(location: str, index: str, items: list[tuple[str, Any]] |
|
|
128
128
|
f.write(pickle.dumps(item[1]))
|
129
129
|
|
130
130
|
|
131
|
+
def clear_pickle_data(location: str, index: str, items: list[str] | str) -> None:
|
132
|
+
"""Clear mixture, target, or noise data pickle file
|
133
|
+
|
134
|
+
:param location: Location of the file
|
135
|
+
:param index: Mixture, target, or noise index
|
136
|
+
:param items: String(s) of data to retrieve
|
137
|
+
"""
|
138
|
+
from pathlib import Path
|
139
|
+
|
140
|
+
if not isinstance(items, list):
|
141
|
+
items = [items]
|
142
|
+
|
143
|
+
for item in items:
|
144
|
+
Path(_get_pickle_name(location, index, item)).unlink(missing_ok=True)
|
145
|
+
|
146
|
+
|
131
147
|
def read_cached_data(location: str, name: str, index: str, items: list[str] | str) -> Any:
|
132
148
|
"""Read cached data from a file
|
133
149
|
|
@@ -143,7 +159,7 @@ def read_cached_data(location: str, name: str, index: str, items: list[str] | st
|
|
143
159
|
|
144
160
|
|
145
161
|
def write_cached_data(location: str, name: str, index: str, items: list[tuple[str, Any]] | tuple[str, Any]) -> None:
|
146
|
-
"""Write
|
162
|
+
"""Write data to a file
|
147
163
|
|
148
164
|
:param location: Location of the mixture database
|
149
165
|
:param name: Data name ('mixture', 'target', or 'noise')
|
@@ -153,3 +169,16 @@ def write_cached_data(location: str, name: str, index: str, items: list[tuple[st
|
|
153
169
|
from os.path import join
|
154
170
|
|
155
171
|
write_pickle_data(join(location, name), index, items)
|
172
|
+
|
173
|
+
|
174
|
+
def clear_cached_data(location: str, name: str, index: str, items: list[str] | str) -> None:
|
175
|
+
"""Remove cached data file(s)
|
176
|
+
|
177
|
+
:param location: Location of the mixture database
|
178
|
+
:param name: Data name ('mixture', 'target', or 'noise')
|
179
|
+
:param index: Data index (mixture, target, or noise ID)
|
180
|
+
:param items: String(s) of data to clear
|
181
|
+
"""
|
182
|
+
from os.path import join
|
183
|
+
|
184
|
+
clear_pickle_data(join(location, name), index, items)
|