sonusai 0.19.6__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +54 -62
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +202 -235
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/truth.py +21 -34
  33. sonusai/mixture/truth_functions/__init__.py +6 -0
  34. sonusai/mixture/truth_functions/crm.py +51 -37
  35. sonusai/mixture/truth_functions/energy.py +95 -50
  36. sonusai/mixture/truth_functions/file.py +12 -8
  37. sonusai/mixture/truth_functions/metadata.py +24 -0
  38. sonusai/mixture/truth_functions/metrics.py +28 -0
  39. sonusai/mixture/truth_functions/phoneme.py +4 -5
  40. sonusai/mixture/truth_functions/sed.py +32 -23
  41. sonusai/mixture/truth_functions/target.py +62 -29
  42. sonusai/mkwav.py +20 -19
  43. sonusai/queries/queries.py +9 -15
  44. sonusai/speech/l2arctic.py +6 -2
  45. sonusai/summarize_metric_spenh.py +1 -1
  46. sonusai/utils/__init__.py +1 -0
  47. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  48. sonusai/utils/audio_devices.py +27 -18
  49. sonusai/utils/docstring.py +6 -3
  50. sonusai/utils/energy_f.py +5 -3
  51. sonusai/utils/human_readable_size.py +6 -6
  52. sonusai/utils/load_object.py +15 -0
  53. sonusai/utils/onnx_utils.py +2 -2
  54. sonusai/utils/print_mixture_details.py +3 -3
  55. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  56. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/RECORD +58 -56
  57. sonusai/mixture/truth_functions/datatypes.py +0 -37
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  59. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
@@ -54,7 +54,7 @@ def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = Non
54
54
 
55
55
  def calc_class_weights_from_mixdb(
56
56
  mixdb: MixtureDatabase,
57
- mixids: GeneralizedIDs | None = None,
57
+ mixids: GeneralizedIDs = "*",
58
58
  other_weight: float = 1,
59
59
  other_index: int = -1,
60
60
  ) -> tuple[np.ndarray, np.ndarray]:
@@ -77,8 +77,6 @@ def calc_class_weights_from_mixdb(
77
77
  from sonusai.mixture import get_class_count_from_mixids
78
78
 
79
79
  count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
80
- if mixdb.truth_mutex and other_weight is not None and other_weight > 0:
81
- count[other_index] = count[other_index] / np.float32(other_weight)
82
80
  total_features = sum(count)
83
81
 
84
82
  weights = np.empty(mixdb.num_classes, dtype=np.float32)
@@ -51,8 +51,8 @@ def calc_optimal_thresholds(
51
51
  AUC[nci] = np.NaN
52
52
  AP[nci] = np.NaN
53
53
  else:
54
- AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None)
55
- AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None)
54
+ AP[nci] = average_precision_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
55
+ AUC[nci] = roc_auc_score(truth_binary[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
56
56
 
57
57
  # Optimal threshold from PR curve, optimizes f-score
58
58
  precision, recall, thrpr = precision_recall_curve(truth_binary[:, nci], predict[:, nci])
@@ -26,7 +26,7 @@ def calc_phase_distance(
26
26
  # weighted mean over all (scalar)
27
27
  reference_mag = np.abs(reference)
28
28
  ref_weight = reference_mag / (np.sum(reference_mag) + eps) # frames x bins
29
- err = np.around(np.sum(ref_weight * rh_angle_diff), 3)
29
+ err = float(np.around(np.sum(ref_weight * rh_angle_diff), 3))
30
30
 
31
31
  # weighted mean over frames (value per bin)
32
32
  err_b = np.zeros(reference.shape[1])
@@ -32,16 +32,16 @@ def calc_speech(hypothesis: np.ndarray, reference: np.ndarray, sample_rate: int
32
32
  llr_mean = np.mean(ll_rs[:llr_len])
33
33
 
34
34
  # Segmental SNR
35
- snr_dist, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
35
+ _, segsnr_dist = _calc_snr(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
36
36
  seg_snr = np.mean(segsnr_dist)
37
37
 
38
38
  # PESQ
39
39
  _pesq = calc_pesq(hypothesis=hypothesis, reference=reference, sample_rate=sample_rate)
40
40
 
41
41
  # Now compute the composite measures
42
- csig = np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5)
43
- cbak = np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5)
44
- covl = np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5)
42
+ csig = float(np.clip(3.093 - 1.029 * llr_mean + 0.603 * _pesq - 0.009 * wss_dist, 1, 5))
43
+ cbak = float(np.clip(1.634 + 0.478 * _pesq - 0.007 * wss_dist + 0.063 * seg_snr, 1, 5))
44
+ covl = float(np.clip(1.594 + 0.805 * _pesq - 0.512 * llr_mean - 0.007 * wss_dist, 1, 5))
45
45
 
46
46
  return SpeechMetrics(_pesq, csig, cbak, covl)
47
47
 
@@ -284,8 +284,8 @@ def _calc_log_likelihood_ratio_measure(
284
284
  hypothesis_frame = np.multiply(hypothesis_frame, window)
285
285
 
286
286
  # (2) Get the autocorrelation lags and LPC parameters used to compute the log likelihood ratio measure.
287
- r_reference, ref_reference, a_reference = _lp_coefficients(reference_frame, p)
288
- r_hypothesis, ref_hypothesis, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
287
+ r_reference, _, a_reference = _lp_coefficients(reference_frame, p)
288
+ _, _, a_hypothesis = _lp_coefficients(hypothesis_frame, p)
289
289
 
290
290
  # (3) Compute the log likelihood ratio measure
291
291
  numerator = np.dot(np.matmul(a_hypothesis, toeplitz(r_reference)), a_hypothesis)
@@ -38,7 +38,7 @@ def class_summary(
38
38
  # TODO: re-work for modern mixdb API
39
39
  y_truth_f, y_predict = get_mixids_data(mixdb, mixids, truth_f, predict) # type: ignore[name-defined]
40
40
 
41
- if not mixdb.truth_mutex and num_classes > 1:
41
+ if num_classes > 1:
42
42
  if not isinstance(predict_thr, np.ndarray):
43
43
  if predict_thr == 0:
44
44
  predict_thr = np.atleast_1d(0.5)
@@ -53,25 +53,16 @@ def class_summary(
53
53
  # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]
54
54
  table_idx = np.array([2, 1, 6, 4, 0, 12, 13, 9])
55
55
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC", "Support"]
56
- if mixdb.truth_mutex:
57
- if len(mixdb.class_labels) >= num_classes - 1: # labels exist with or without Other
58
- row_n = mixdb.class_labels
59
- if len(mixdb.class_labels) == num_classes - 1: # Other label does not exist, so add it
60
- row_n.append("Other")
61
- else:
62
- row_n = [f"Class {i}" for i in range(1, num_classes)]
63
- row_n.append("Other")
56
+ if len(mixdb.class_labels) == num_classes:
57
+ row_n = mixdb.class_labels
64
58
  else:
65
- if len(mixdb.class_labels) == num_classes:
66
- row_n = mixdb.class_labels
67
- else:
68
- row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
59
+ row_n = [f"Class {i}" for i in range(1, num_classes + 1)]
69
60
 
70
- df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n)
61
+ df = pd.DataFrame(metrics[:, table_idx], columns=col_n, index=row_n) # pyright: ignore [reportArgumentType]
71
62
 
72
63
  # [miPPV, miTPR, miF1, miFPR, miACC, miAP, miAUC, TPSUM]
73
64
  avg_row_n = ["Macro-avg", "Micro-avg", "Weighted-avg"]
74
- dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n)
65
+ dfavg = pd.DataFrame(metavg, columns=col_n, index=avg_row_n) # pyright: ignore [reportArgumentType]
75
66
 
76
67
  # dfblank = pd.DataFrame([''])
77
68
  # pd.concat([df, dfblank, dfblank, dfavg])
@@ -37,7 +37,7 @@ def confusion_matrix_summary(
37
37
  ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict) # type: ignore[name-defined]
38
38
 
39
39
  # Check predict_thr array or scalar and return final scalar predict_thr value
40
- if not mixdb.truth_mutex and num_classes > 1:
40
+ if num_classes > 1:
41
41
  if not isinstance(predict_thr, np.ndarray):
42
42
  if predict_thr == 0:
43
43
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -61,31 +61,15 @@ def confusion_matrix_summary(
61
61
  else:
62
62
  class_names = [f"Class {i}" for i in range(1, num_classes + 1)]
63
63
 
64
- class_nums = [f"{i}" for i in range(1, num_classes + 1)]
65
-
66
- if mixdb.truth_mutex:
67
- # single-label mode force to argmax mode
68
- predict_thr = np.array(0, dtype=np.float32)
69
- _, _, cm, cmn, _, _ = one_hot(ytrue, ypred, predict_thr, truth_thr, timesteps)
70
- row_n = class_names
71
- row_n[-1] = "Other"
72
- # mux = pd.MultiIndex.from_product([['Single-label/mutex mode, truth thr = {}'.format(truth_thr)],
73
- # class_nums])
74
- # mux = pd.MultiIndex.from_product([['truth thr = {}'.format(truth_thr)], class_nums])
75
-
76
- cmdf = pd.DataFrame(cm, index=row_n, columns=class_nums, dtype=np.int32)
77
- cmndf = pd.DataFrame(cmn, index=row_n, columns=class_nums, dtype=np.float32)
78
-
79
- else:
80
- _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
81
- cname = class_names[class_idx]
82
- row_n = ["TrueN", "TrueP"]
83
- col_n = ["N-" + cname, "P-" + cname]
84
- cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)
85
- cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)
86
- # add thresholds in 3rd row
87
- pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n)
88
- cmdf = pd.concat([cmdf, pdnote])
89
- cmndf = pd.concat([cmndf, pdnote])
64
+ _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
65
+ cname = class_names[class_idx]
66
+ row_n = ["TrueN", "TrueP"]
67
+ col_n = ["N-" + cname, "P-" + cname]
68
+ cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32) # pyright: ignore [reportArgumentType]
69
+ cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32) # pyright: ignore [reportArgumentType]
70
+ # add thresholds in 3rd row
71
+ pdnote = pd.DataFrame(np.atleast_2d([predict_thr, truth_thr]), index=["p/t thr:"], columns=col_n) # pyright: ignore [reportArgumentType, reportCallIssue]
72
+ cmdf = pd.concat([cmdf, pdnote])
73
+ cmndf = pd.concat([cmndf, pdnote])
90
74
 
91
75
  return cmdf, cmndf
@@ -185,11 +185,11 @@ def one_hot(
185
185
  AP = np.NaN
186
186
  # threshold_optpr[nci] = np.NaN
187
187
  else:
188
- AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None)
188
+ AP = average_precision_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
189
189
  if len(np.unique(truthb[:, nci])) < 2: # if active classes not > 1 AUC must be NaN
190
190
  AUC = np.NaN # i.e. all ones sklearn auc will fail
191
191
  else:
192
- AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None)
192
+ AUC = roc_auc_score(truthb[:, nci], predict[:, nci], average=None) # pyright: ignore [reportArgumentType]
193
193
  # # Optimal threshold from PR curve, optimizes f-score
194
194
  # precision, recall, thresholds = precision_recall_curve(truthb[:, nci], predict[:, nci])
195
195
  # fscore = (2 * precision * recall) / (precision + recall)
@@ -263,7 +263,7 @@ def one_hot(
263
263
  ] # specific format, last 3 are unique
264
264
 
265
265
  # weighted average TBD
266
- wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0)
266
+ wp, wr, wf1, _ = precision_recall_fscore_support(truthb, predb, average="weighted", zero_division=0) # pyright: ignore [reportArgumentType]
267
267
  if np.sum(truthb):
268
268
  taidx = np.sum(truthb, axis=0) > 0
269
269
  wap = average_precision_score(truthb[:, taidx], predict[:, taidx], average="weighted")
@@ -48,7 +48,7 @@ def snr_summary(
48
48
  snr_mixids = get_mixids_from_snr(mixdb=mixdb, mixids=mixid)
49
49
 
50
50
  # Check predict_thr array or scalar and return final scalar predict_thr value
51
- if not mixdb.truth_mutex and num_classes > 1:
51
+ if num_classes > 1:
52
52
  if not isinstance(predict_thr, np.ndarray):
53
53
  if predict_thr == 0:
54
54
  # multi-label predict_thr scalar 0 force to 0.5 default
@@ -84,7 +84,7 @@ def snr_summary(
84
84
  for ii, snr in enumerate(snr_mixids):
85
85
  # TODO: re-work for modern mixdb API
86
86
  y_truth, y_predict = get_mixids_data(mixdb, snr_mixids[snr], truth_f, predict) # type: ignore[name-defined]
87
- _, metrics, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
87
+ _, _, _, _, _, mavg = one_hot(y_truth, y_predict, predict_thr, truth_thr, timesteps)
88
88
 
89
89
  # mavg macro, micro, weighted: [PPV, TPR, F1, FPR, ACC, mAP, mAUC, TPSUM]
90
90
  macro_avg[ii, :] = mavg[0, 0:7]
@@ -104,21 +104,21 @@ def snr_summary(
104
104
 
105
105
  # SNR format: PPV, TPR, F1, FPR, ACC, AP, AUC
106
106
  col_n = ["PPV", "TPR", "F1", "FPR", "ACC", "AP", "AUC"]
107
- snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n)
107
+ snr_macrodf = pd.DataFrame(macro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
108
108
  snr_macrodf.sort_index(ascending=False, inplace=True)
109
109
 
110
- snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n)
110
+ snr_microdf = pd.DataFrame(micro_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
111
111
  snr_microdf.sort_index(ascending=False, inplace=True)
112
112
 
113
- snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n)
113
+ snr_wghtdf = pd.DataFrame(wghtd_avg, index=list(snr_mixids.keys()), columns=col_n) # pyright: ignore [reportArgumentType]
114
114
  snr_wghtdf.sort_index(ascending=False, inplace=True)
115
115
 
116
116
  # Add segmental SNR columns if provided
117
117
  if segsnr is not None:
118
118
  ssnrdf = pd.DataFrame(
119
119
  ssnr_stats,
120
- index=list(snr_mixids.keys()),
121
- columns=["SSNRavg", "SSNR80p", "SSNRmax"],
120
+ index=list(snr_mixids.keys()), # pyright: ignore [reportArgumentType]
121
+ columns=["SSNRavg", "SSNR80p", "SSNRmax"], # pyright: ignore [reportArgumentType]
122
122
  )
123
123
  ssnrdf.sort_index(ascending=False, inplace=True)
124
124
  snr_macrodf = pd.concat([snr_macrodf, ssnrdf], axis=1)
@@ -46,19 +46,15 @@ from .constants import SAMPLE_RATE
46
46
  from .constants import VALID_AUGMENTATIONS
47
47
  from .constants import VALID_CONFIGS
48
48
  from .constants import VALID_NOISE_MIX_MODES
49
+ from .data_io import clear_cached_data
49
50
  from .data_io import read_cached_data
50
51
  from .data_io import write_cached_data
51
52
  from .datatypes import AudioF
52
- from .datatypes import AudiosF
53
- from .datatypes import AudiosT
54
53
  from .datatypes import AudioStatsMetrics
55
54
  from .datatypes import AudioT
56
55
  from .datatypes import Augmentation
57
56
  from .datatypes import AugmentationRule
58
- from .datatypes import AugmentationRules
59
- from .datatypes import Augmentations
60
57
  from .datatypes import AugmentedTarget
61
- from .datatypes import AugmentedTargets
62
58
  from .datatypes import ClassCount
63
59
  from .datatypes import EnergyF
64
60
  from .datatypes import EnergyT
@@ -70,35 +66,27 @@ from .datatypes import GenFTData
70
66
  from .datatypes import GenMixData
71
67
  from .datatypes import ImpulseResponseData
72
68
  from .datatypes import ImpulseResponseFile
73
- from .datatypes import ImpulseResponseFiles
74
- from .datatypes import ListAudiosT
75
69
  from .datatypes import MetricDoc
76
70
  from .datatypes import MetricDocs
77
71
  from .datatypes import Mixture
78
72
  from .datatypes import MixtureDatabaseConfig
79
- from .datatypes import Mixtures
80
73
  from .datatypes import NoiseFile
81
- from .datatypes import NoiseFiles
82
74
  from .datatypes import Predict
83
75
  from .datatypes import Segsnr
84
76
  from .datatypes import SnrFMetrics
85
77
  from .datatypes import SpectralMask
86
- from .datatypes import SpectralMasks
87
78
  from .datatypes import SpeechMetadata
88
79
  from .datatypes import SpeechMetrics
89
80
  from .datatypes import TargetFile
90
- from .datatypes import TargetFiles
91
81
  from .datatypes import TransformConfig
92
82
  from .datatypes import Truth
93
83
  from .datatypes import TruthConfig
94
84
  from .datatypes import TruthConfigs
95
85
  from .datatypes import TruthDict
96
86
  from .datatypes import TruthParameter
97
- from .datatypes import TruthParameters
98
87
  from .datatypes import UniversalSNR
99
88
  from .feature import get_audio_from_feature
100
89
  from .feature import get_feature_from_audio
101
- from .generation import generate_mixtures
102
90
  from .generation import get_all_snrs_from_config
103
91
  from .generation import initialize_db
104
92
  from .generation import populate_class_label_table
@@ -111,17 +99,14 @@ from .generation import populate_target_file_table
111
99
  from .generation import populate_top_table
112
100
  from .generation import populate_truth_parameters_table
113
101
  from .generation import update_mixid_width
114
- from .generation import update_mixture
102
+ from .generation import update_mixture_table
115
103
  from .helpers import augmented_noise_samples
116
104
  from .helpers import augmented_target_samples
117
105
  from .helpers import check_audio_files_exist
118
106
  from .helpers import forward_transform
119
107
  from .helpers import frames_from_samples
120
108
  from .helpers import get_audio_from_transform
121
- from .helpers import get_ft
122
- from .helpers import get_segsnr
123
109
  from .helpers import get_transform_from_audio
124
- from .helpers import get_truth
125
110
  from .helpers import inverse_transform
126
111
  from .helpers import mixture_metadata
127
112
  from .helpers import write_mixture_metadata
@@ -1,12 +1,11 @@
1
1
  from sonusai.mixture.datatypes import AudioT
2
2
  from sonusai.mixture.datatypes import Augmentation
3
3
  from sonusai.mixture.datatypes import AugmentationRule
4
- from sonusai.mixture.datatypes import AugmentationRules
5
4
  from sonusai.mixture.datatypes import ImpulseResponseData
6
5
  from sonusai.mixture.datatypes import OptionalNumberStr
7
6
 
8
7
 
9
- def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> AugmentationRules:
8
+ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
10
9
  """Generate augmentation rules from list of input rules
11
10
 
12
11
  :param rules: Dictionary of augmentation config rule[s]
@@ -25,7 +24,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
25
24
  rule = _parse_ir(rule, num_ir)
26
25
  processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
27
26
 
28
- return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules]
27
+ return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
29
28
 
30
29
 
31
30
  def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
@@ -163,7 +162,7 @@ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr
163
162
  return length
164
163
 
165
164
 
166
- def get_mixups(augmentations: AugmentationRules) -> list[int]:
165
+ def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
167
166
  """Get a list of mixup values used
168
167
 
169
168
  :param augmentations: List of augmentations
@@ -172,7 +171,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
172
171
  return sorted({augmentation.mixup for augmentation in augmentations})
173
172
 
174
173
 
175
- def get_augmentation_indices_for_mixup(augmentations: AugmentationRules, mixup: int) -> list[int]:
174
+ def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
176
175
  """Get a list of augmentation indices for a given mixup value
177
176
 
178
177
  :param augmentations: List of augmentations
@@ -327,4 +326,4 @@ def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
327
326
  if _rule_has_rand(processed_rule):
328
327
  processed_rule = _generate_random_rule(processed_rule, num_ir)
329
328
 
330
- return dataclass_from_dict(Augmentation, processed_rule)
329
+ return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
@@ -3,7 +3,7 @@ from sonusai.mixture.datatypes import GeneralizedIDs
3
3
  from sonusai.mixture.mixdb import MixtureDatabase
4
4
 
5
5
 
6
- def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs | None = None) -> ClassCount:
6
+ def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
7
7
  """Sums the class counts for given mixids"""
8
8
  total_class_count = [0] * mixdb.num_classes
9
9
  m_ids = mixdb.mixids_to_list(mixids)
sonusai/mixture/config.py CHANGED
@@ -1,9 +1,8 @@
1
1
  from sonusai.mixture.datatypes import ImpulseResponseFile
2
- from sonusai.mixture.datatypes import ImpulseResponseFiles
3
- from sonusai.mixture.datatypes import NoiseFiles
4
- from sonusai.mixture.datatypes import SpectralMasks
5
- from sonusai.mixture.datatypes import TargetFiles
6
- from sonusai.mixture.datatypes import TruthParameters
2
+ from sonusai.mixture.datatypes import NoiseFile
3
+ from sonusai.mixture.datatypes import SpectralMask
4
+ from sonusai.mixture.datatypes import TargetFile
5
+ from sonusai.mixture.datatypes import TruthParameter
7
6
 
8
7
 
9
8
  def raw_load_config(name: str) -> dict:
@@ -210,7 +209,7 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
210
209
  return new_config
211
210
 
212
211
 
213
- def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
212
+ def get_target_files(config: dict, show_progress: bool = False) -> list[TargetFile]:
214
213
  """Get the list of target files from a config
215
214
 
216
215
  :param config: Config dictionary
@@ -223,7 +222,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
223
222
  from sonusai.utils import par_track
224
223
  from sonusai.utils import track
225
224
 
226
- from .datatypes import TargetFiles
225
+ from .datatypes import TargetFile
227
226
 
228
227
  class_indices = config["class_indices"]
229
228
  if not isinstance(class_indices, list):
@@ -255,7 +254,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
255
254
  if any(class_index > num_classes for class_index in target_file["class_indices"]):
256
255
  raise ValueError(f"class index elements must not be greater than {num_classes}")
257
256
 
258
- return dataclass_from_dict(TargetFiles, target_files)
257
+ return dataclass_from_dict(list[TargetFile], target_files)
259
258
 
260
259
 
261
260
  def append_target_files(
@@ -294,6 +293,7 @@ def append_target_files(
294
293
  if tokens is None:
295
294
  tokens = {}
296
295
 
296
+ truth_configs_merged = deepcopy(truth_configs)
297
297
  if isinstance(entry, dict):
298
298
  if "name" in entry:
299
299
  in_name = entry["name"]
@@ -312,15 +312,11 @@ def append_target_files(
312
312
  raise AttributeError(
313
313
  f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
314
314
  )
315
- truth_configs_merged = {}
316
- for key in truth_configs_override:
317
- truth_configs_merged[key] = deepcopy(truth_configs[key])
318
- if truth_configs_override[key] is not None:
315
+ if key in truth_configs_override:
319
316
  truth_configs_merged[key] |= truth_configs_override[key]
320
317
  level_type = entry.get("level_type", level_type)
321
318
  else:
322
319
  in_name = entry
323
- truth_configs_merged = deepcopy(truth_configs)
324
320
 
325
321
  in_name, new_tokens = tokenized_expand(in_name)
326
322
  tokens.update(new_tokens)
@@ -416,7 +412,7 @@ def append_target_files(
416
412
  return target_files
417
413
 
418
414
 
419
- def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
415
+ def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
420
416
  """Get the list of noise files from a config
421
417
 
422
418
  :param config: Config dictionary
@@ -429,7 +425,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
429
425
  from sonusai.utils import par_track
430
426
  from sonusai.utils import track
431
427
 
432
- from .datatypes import NoiseFiles
428
+ from .datatypes import NoiseFile
433
429
 
434
430
  noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
435
431
 
@@ -437,7 +433,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
437
433
  noise_files = par_track(_get_num_samples, noise_files, progress=progress)
438
434
  progress.close()
439
435
 
440
- return dataclass_from_dict(NoiseFiles, noise_files)
436
+ return dataclass_from_dict(list[NoiseFile], noise_files)
441
437
 
442
438
 
443
439
  def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
@@ -522,26 +518,25 @@ def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[di
522
518
  return noise_files
523
519
 
524
520
 
525
- def get_impulse_response_files(config: dict) -> ImpulseResponseFiles:
521
+ def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
526
522
  """Get the list of impulse response files from a config
527
523
 
528
524
  :param config: Config dictionary
529
525
  :return: List of impulse response files
530
526
  """
531
- return [ImpulseResponseFile(entry["name"], entry["tags"]) for entry in config["impulse_responses"]]
532
- # from itertools import chain
533
- #
534
- # return list(
535
- # chain.from_iterable(
536
- # [
537
- # append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry["tags"]))
538
- # for entry in config["impulse_responses"]
539
- # ]
540
- # )
541
- # )
542
-
543
-
544
- def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
527
+ from itertools import chain
528
+
529
+ return list(
530
+ chain.from_iterable(
531
+ [
532
+ append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
533
+ for entry in config["impulse_responses"]
534
+ ]
535
+ )
536
+ )
537
+
538
+
539
+ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
545
540
  """Process impulse response files list and append as needed
546
541
 
547
542
  :param entry: Impulse response file entry to append to the list
@@ -569,7 +564,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
569
564
  if not names:
570
565
  raise OSError(f"Could not find {in_name}. Make sure path exists")
571
566
 
572
- impulse_response_files: list[str] = []
567
+ impulse_response_files: list[ImpulseResponseFile] = []
573
568
  for name in names:
574
569
  ext = splitext(name)[1].lower()
575
570
  dir_name = dirname(name)
@@ -607,14 +602,14 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
607
602
  raise OSError(f"Error processing {name}: {e}") from e
608
603
  else:
609
604
  validate_input_file(name)
610
- impulse_response_files.append(tokenized_replace(name, tokens))
605
+ impulse_response_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags))
611
606
  except Exception as e:
612
607
  raise OSError(f"Error processing {name}: {e}") from e
613
608
 
614
609
  return impulse_response_files
615
610
 
616
611
 
617
- def get_spectral_masks(config: dict) -> SpectralMasks:
612
+ def get_spectral_masks(config: dict) -> list[SpectralMask]:
618
613
  """Get the list of spectral masks from a config
619
614
 
620
615
  :param config: Config dictionary
@@ -623,12 +618,12 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
623
618
  from sonusai.utils import dataclass_from_dict
624
619
 
625
620
  try:
626
- return dataclass_from_dict(SpectralMasks, config["spectral_masks"])
621
+ return dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
627
622
  except Exception as e:
628
623
  raise ValueError(f"Error in spectral_masks: {e}") from e
629
624
 
630
625
 
631
- def get_truth_parameters(config: dict) -> TruthParameters:
626
+ def get_truth_parameters(config: dict) -> list[TruthParameter]:
632
627
  """Get the list of truth parameters from a config
633
628
 
634
629
  :param config: Config dictionary
@@ -637,26 +632,21 @@ def get_truth_parameters(config: dict) -> TruthParameters:
637
632
  from copy import deepcopy
638
633
 
639
634
  from sonusai.mixture import truth_functions
640
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
641
635
 
642
636
  from .constants import REQUIRED_TRUTH_CONFIGS
643
637
  from .datatypes import TruthParameter
644
638
 
645
- truth_parameters: TruthParameters = []
639
+ truth_parameters: list[TruthParameter] = []
646
640
  for name, truth_config in config["truth_configs"].items():
647
641
  optional_config = deepcopy(truth_config)
648
642
  for key in REQUIRED_TRUTH_CONFIGS:
649
643
  del optional_config[key]
650
644
 
651
- t_config = TruthFunctionConfig(
652
- feature=config["feature"],
653
- num_classes=config["num_classes"],
654
- class_indices=[1],
655
- target_gain=1,
656
- config=optional_config,
645
+ parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
646
+ config["feature"],
647
+ config["num_classes"],
648
+ optional_config,
657
649
  )
658
-
659
- parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
660
650
  truth_parameters.append(TruthParameter(name, parameters))
661
651
 
662
652
  return truth_parameters
@@ -128,6 +128,22 @@ def write_pickle_data(location: str, index: str, items: list[tuple[str, Any]] |
128
128
  f.write(pickle.dumps(item[1]))
129
129
 
130
130
 
131
+ def clear_pickle_data(location: str, index: str, items: list[str] | str) -> None:
132
+ """Clear mixture, target, or noise data pickle file
133
+
134
+ :param location: Location of the file
135
+ :param index: Mixture, target, or noise index
136
+ :param items: String(s) of data to retrieve
137
+ """
138
+ from pathlib import Path
139
+
140
+ if not isinstance(items, list):
141
+ items = [items]
142
+
143
+ for item in items:
144
+ Path(_get_pickle_name(location, index, item)).unlink(missing_ok=True)
145
+
146
+
131
147
  def read_cached_data(location: str, name: str, index: str, items: list[str] | str) -> Any:
132
148
  """Read cached data from a file
133
149
 
@@ -143,7 +159,7 @@ def read_cached_data(location: str, name: str, index: str, items: list[str] | st
143
159
 
144
160
 
145
161
  def write_cached_data(location: str, name: str, index: str, items: list[tuple[str, Any]] | tuple[str, Any]) -> None:
146
- """Write mixture data to a file
162
+ """Write data to a file
147
163
 
148
164
  :param location: Location of the mixture database
149
165
  :param name: Data name ('mixture', 'target', or 'noise')
@@ -153,3 +169,16 @@ def write_cached_data(location: str, name: str, index: str, items: list[tuple[st
153
169
  from os.path import join
154
170
 
155
171
  write_pickle_data(join(location, name), index, items)
172
+
173
+
174
+ def clear_cached_data(location: str, name: str, index: str, items: list[str] | str) -> None:
175
+ """Remove cached data file(s)
176
+
177
+ :param location: Location of the mixture database
178
+ :param name: Data name ('mixture', 'target', or 'noise')
179
+ :param index: Data index (mixture, target, or noise ID)
180
+ :param items: String(s) of data to clear
181
+ """
182
+ from os.path import join
183
+
184
+ clear_pickle_data(join(location, name), index, items)