sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py CHANGED
@@ -41,11 +41,12 @@ TBD not sure below make sense, need to continue ??
41
41
  3. Classification
42
42
 
43
43
  Outputs the following to opredict-<TIMESTAMP> directory:
44
- <id>.h5
45
- dataset: predict
44
+ <id>
45
+ predict.pkl
46
46
  onnx_predict.log
47
47
 
48
48
  """
49
+
49
50
  import signal
50
51
 
51
52
 
@@ -54,7 +55,7 @@ def signal_handler(_sig, _frame):
54
55
 
55
56
  from sonusai import logger
56
57
 
57
- logger.info('Canceled due to keyboard interrupt')
58
+ logger.info("Canceled due to keyboard interrupt")
58
59
  sys.exit(1)
59
60
 
60
61
 
@@ -69,12 +70,12 @@ def main() -> None:
69
70
 
70
71
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
71
72
 
72
- verbose = args['--verbose']
73
- wav = args['--write-wav']
74
- mixids = args['--mixid']
75
- include = args['--include']
76
- model_path = args['MODEL']
77
- data_paths = args['DATA']
73
+ verbose = args["--verbose"]
74
+ wav = args["--write-wav"]
75
+ mixids = args["--mixid"]
76
+ include = args["--include"]
77
+ model_path = args["MODEL"]
78
+ data_paths = args["DATA"]
78
79
 
79
80
  from os import makedirs
80
81
  from os.path import abspath
@@ -103,8 +104,8 @@ def main() -> None:
103
104
  from sonusai.utils import write_audio
104
105
 
105
106
  mixdb_path = None
106
- mixdb = None
107
- p_mixids = None
107
+ mixdb: MixtureDatabase | None = None
108
+ p_mixids: list[int] = []
108
109
  entries: list[PathInfo] = []
109
110
 
110
111
  if len(data_paths) == 1 and isdir(data_paths[0]):
@@ -113,96 +114,98 @@ def main() -> None:
113
114
  mixdb_path = data_paths[0]
114
115
  else:
115
116
  # search all data paths for .wav, .flac (or whatever is specified in include)
116
- in_basename = ''
117
+ in_basename = ""
117
118
 
118
- output_dir = create_ts_name('opredict-' + in_basename)
119
+ output_dir = create_ts_name("opredict-" + in_basename)
119
120
  makedirs(output_dir, exist_ok=True)
120
121
 
121
122
  # Setup logging file
122
- create_file_handler(join(output_dir, 'onnx-predict.log'))
123
+ create_file_handler(join(output_dir, "onnx-predict.log"))
123
124
  update_console_handler(verbose)
124
- initial_log_messages('onnx_predict')
125
+ initial_log_messages("onnx_predict")
125
126
 
126
127
  providers = ort.get_available_providers()
127
- logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
128
+ logger.info(f"Loaded ONNX Runtime, available providers: {providers}.")
128
129
 
129
130
  session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
130
131
  if hparams is None:
131
- logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
132
+ logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
132
133
  raise SystemExit(1)
133
134
  if len(sess_inputs) != 1:
134
- logger.error(f'Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
135
+ logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")
135
136
 
136
137
  in0name = sess_inputs[0].name
137
138
  in0type = sess_inputs[0].type
138
139
  out_names = [n.name for n in session.get_outputs()]
139
140
 
140
- logger.info(f'Read and compiled ONNX model from {model_path}.')
141
+ logger.info(f"Read and compiled ONNX model from {model_path}.")
141
142
 
142
143
  if mixdb_path is not None:
143
144
  # Assume it's a single path to SonusAI mixdb subdir
144
- logger.debug(f'Attempting to load mixture database from {mixdb_path}')
145
+ logger.debug(f"Attempting to load mixture database from {mixdb_path}")
145
146
  mixdb = MixtureDatabase(mixdb_path)
146
- logger.info(f'SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
147
+ logger.info(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes")
147
148
  p_mixids = mixdb.mixids_to_list(mixids)
148
149
  if len(p_mixids) != mixdb.num_mixtures:
149
- logger.info(f'Processing a subset of {p_mixids} from available mixtures.')
150
+ logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
150
151
  else:
151
152
  for p in data_paths:
152
- location = join(realpath(abspath(p)), '**', include)
153
- logger.debug(f'Processing {location}')
153
+ location = join(realpath(abspath(p)), "**", include)
154
+ logger.debug(f"Processing {location}")
154
155
  for file in braced_iglob(pathname=location, recursive=True):
155
156
  name = file
156
157
  entries.append(PathInfo(abs_path=file, audio_filepath=name))
157
- logger.info(f'{len(data_paths)} data paths specified, found {len(entries)} audio files.')
158
+ logger.info(f"{len(data_paths)} data paths specified, found {len(entries)} audio files.")
158
159
 
159
- if in0type.find('float16') != -1:
160
+ if in0type.find("float16") != -1:
160
161
  model_is_fp16 = True
161
- logger.info(f'Detected input of float16, converting all feature inputs to that type.')
162
+ logger.info("Detected input of float16, converting all feature inputs to that type.")
162
163
  else:
163
164
  model_is_fp16 = False
164
165
 
165
- if mixdb_path is not None and hparams['batch_size'] == 1:
166
+ if mixdb is not None and hparams["batch_size"] == 1:
166
167
  # mixdb input
167
168
  # Assume (of course) that mixdb feature, etc. is what model expects
168
- if hparams['feature'] != mixdb.feature:
169
- logger.warning(f'Mixture feature does not match model feature, this inference run may fail.')
169
+ if hparams["feature"] != mixdb.feature:
170
+ logger.warning("Mixture feature does not match model feature, this inference run may fail.")
170
171
  # no choice, can't use hparams.feature since it's different from the mixdb
171
172
  feature_mode = mixdb.feature
172
173
 
173
174
  for mixid in p_mixids:
174
175
  # frames x stride x feature_params
175
176
  feature, _ = mixdb.mixture_ft(mixid)
176
- if hparams['timesteps'] == 0:
177
+ if hparams["timesteps"] == 0:
177
178
  # no timestep dimension, reshape will handle
178
179
  timesteps = 0
179
180
  else:
180
181
  # fit frames into timestep dimension (TSE mode)
181
182
  timesteps = feature.shape[0]
182
183
 
183
- feature, _ = reshape_inputs(feature=feature,
184
- batch_size=1,
185
- timesteps=timesteps,
186
- flatten=hparams['flatten'],
187
- add1ch=hparams['add1ch'])
184
+ feature, _ = reshape_inputs(
185
+ feature=feature,
186
+ batch_size=1,
187
+ timesteps=timesteps,
188
+ flatten=hparams["flatten"],
189
+ add1ch=hparams["add1ch"],
190
+ )
188
191
  if model_is_fp16:
189
- feature = np.float16(feature) # type: ignore
192
+ feature = np.float16(feature) # type: ignore[assignment]
190
193
  # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
191
194
  predict = session.run(out_names, {in0name: feature})[0]
192
195
  # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
193
196
  output_fname = join(output_dir, mixdb.mixtures[mixid].name)
194
- with h5py.File(output_fname, 'a') as f:
195
- if 'predict' in f:
196
- del f['predict']
197
- f.create_dataset('predict', data=predict)
197
+ with h5py.File(output_fname, "a") as f:
198
+ if "predict" in f:
199
+ del f["predict"]
200
+ f.create_dataset("predict", data=predict)
198
201
  if wav:
199
202
  # note only makes sense if model is predicting audio, i.e., timestep dimension exists
200
203
  # predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
201
204
  predict = np.transpose(predict, [1, 0, 2])
202
205
  predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
203
- owav_name = splitext(output_fname)[0] + '_predict.wav'
206
+ owav_name = splitext(output_fname)[0] + "_predict.wav"
204
207
  write_audio(owav_name, predict_audio)
205
208
 
206
209
 
207
- if __name__ == '__main__':
210
+ if __name__ == "__main__":
208
211
  main()
@@ -1,6 +1,8 @@
1
1
  # SonusAI query utilities
2
+ # ruff: noqa: F401
3
+
2
4
  from .queries import get_mixids_from_noise
3
5
  from .queries import get_mixids_from_snr
4
6
  from .queries import get_mixids_from_target
5
7
  from .queries import get_mixids_from_truth_function
6
- from .queries import get_mixids_from_truth_index
8
+ from .queries import get_mixids_from_class_indices
@@ -1,14 +1,16 @@
1
+ from collections.abc import Callable
1
2
  from typing import Any
2
- from typing import Callable
3
3
 
4
4
  from sonusai.mixture.datatypes import GeneralizedIDs
5
5
  from sonusai.mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
- def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
9
- field: str,
10
- mixids: GeneralizedIDs = None,
11
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
8
+ def get_mixids_from_mixture_field_predicate(
9
+ mixdb: MixtureDatabase,
10
+ field: str,
11
+ mixids: GeneralizedIDs = "*",
12
+ predicate: Callable[[Any], bool] | None = None,
13
+ ) -> dict[int, list[int]]:
12
14
  """
13
15
  Generate mixture IDs based on mixture field and predicate
14
16
  Return a dictionary where:
@@ -18,6 +20,7 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
18
20
  mixid_out = mixdb.mixids_to_list(mixids)
19
21
 
20
22
  if predicate is None:
23
+
21
24
  def predicate(_: Any) -> bool:
22
25
  return True
23
26
 
@@ -30,7 +33,7 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
30
33
  criteria_set.add(v)
31
34
  elif predicate(value):
32
35
  criteria_set.add(value)
33
- criteria = sorted(list(criteria_set))
36
+ criteria = sorted(criteria_set)
34
37
 
35
38
  result: dict[int, list[int]] = {}
36
39
  for criterion in criteria:
@@ -47,22 +50,27 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
47
50
  return result
48
51
 
49
52
 
50
- def get_mixids_from_truth_settings_field_predicate(mixdb: MixtureDatabase,
51
- field: str,
52
- mixids: GeneralizedIDs = None,
53
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
53
+ def get_mixids_from_truth_configs_field_predicate(
54
+ mixdb: MixtureDatabase,
55
+ field: str,
56
+ mixids: GeneralizedIDs = "*",
57
+ predicate: Callable[[Any], bool] | None = None,
58
+ ) -> dict[int, list[int]]:
54
59
  """
55
- Generate mixture IDs based on target truth_settings field and predicate
60
+ Generate mixture IDs based on target truth_configs field and predicate
56
61
  Return a dictionary where:
57
62
  - keys are the matching field values
58
63
  - values are lists of the mixids that match the criteria
59
64
  """
65
+ from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
66
+
60
67
  mixid_out = mixdb.mixids_to_list(mixids)
61
68
 
62
69
  # Get all field values
63
- values = get_all_truth_settings_values_from_field(mixdb, field)
70
+ values = get_all_truth_configs_values_from_field(mixdb, field)
64
71
 
65
72
  if predicate is None:
73
+
66
74
  def predicate(_: Any) -> bool:
67
75
  return True
68
76
 
@@ -75,10 +83,14 @@ def get_mixids_from_truth_settings_field_predicate(mixdb: MixtureDatabase,
75
83
  indices = []
76
84
  for t_id in mixdb.target_file_ids:
77
85
  target = mixdb.target_file(t_id)
78
- for truth_setting in target.truth_settings:
79
- if value in getattr(truth_setting, field):
80
- indices.append(t_id)
81
- indices = sorted(list(set(indices)))
86
+ for truth_config in target.truth_configs.values():
87
+ if field in REQUIRED_TRUTH_CONFIGS:
88
+ if value in getattr(truth_config, field):
89
+ indices.append(t_id)
90
+ else:
91
+ if value in getattr(truth_config.config, field):
92
+ indices.append(t_id)
93
+ indices = sorted(set(indices))
82
94
 
83
95
  mixids = []
84
96
  for index in indices:
@@ -86,61 +98,66 @@ def get_mixids_from_truth_settings_field_predicate(mixdb: MixtureDatabase,
86
98
  if index in [target.file_id for target in mixdb.mixture(m_id).targets]:
87
99
  mixids.append(m_id)
88
100
 
89
- mixids = sorted(list(set(mixids)))
101
+ mixids = sorted(set(mixids))
90
102
  if mixids:
91
103
  result[value] = mixids
92
104
 
93
105
  return result
94
106
 
95
107
 
96
- def get_all_truth_settings_values_from_field(mixdb: MixtureDatabase, field: str) -> list:
108
+ def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str) -> list:
97
109
  """
98
- Generate a list of all values corresponding to the given field in truth_settings
110
+ Generate a list of all values corresponding to the given field in truth_configs
99
111
  """
112
+ from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
113
+
100
114
  result = []
101
115
  for target in mixdb.target_files:
102
- for truth_setting in target.truth_settings:
103
- value = getattr(truth_setting, field)
116
+ for truth_config in target.truth_configs.values():
117
+ if field in REQUIRED_TRUTH_CONFIGS:
118
+ value = getattr(truth_config, field)
119
+ else:
120
+ value = getattr(truth_config.config, field, None)
104
121
  if isinstance(value, str):
105
122
  value = [value]
106
123
  result.extend(value)
107
124
 
108
- return sorted(list(set(result)))
125
+ return sorted(set(result))
109
126
 
110
127
 
111
- def get_mixids_from_noise(mixdb: MixtureDatabase,
112
- mixids: GeneralizedIDs = None,
113
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
128
+ def get_mixids_from_noise(
129
+ mixdb: MixtureDatabase,
130
+ mixids: GeneralizedIDs = "*",
131
+ predicate: Callable[[Any], bool] | None = None,
132
+ ) -> dict[int, list[int]]:
114
133
  """
115
134
  Generate mixids based on noise index predicate
116
135
  Return a dictionary where:
117
136
  - keys are the noise indices
118
137
  - values are lists of the mixids that match the noise index
119
138
  """
120
- return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
121
- mixids=mixids,
122
- field='noise_id',
123
- predicate=predicate)
139
+ return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="noise_id", predicate=predicate)
124
140
 
125
141
 
126
- def get_mixids_from_target(mixdb: MixtureDatabase,
127
- mixids: GeneralizedIDs = None,
128
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
142
+ def get_mixids_from_target(
143
+ mixdb: MixtureDatabase,
144
+ mixids: GeneralizedIDs = "*",
145
+ predicate: Callable[[Any], bool] | None = None,
146
+ ) -> dict[int, list[int]]:
129
147
  """
130
148
  Generate mixids based on a target index predicate
131
149
  Return a dictionary where:
132
150
  - keys are the target indices
133
151
  - values are lists of the mixids that match the target index
134
152
  """
135
- return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
136
- mixids=mixids,
137
- field='target_ids',
138
- predicate=predicate)
153
+ return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="target_ids", predicate=predicate)
139
154
 
140
155
 
141
- def get_mixids_from_snr(mixdb: MixtureDatabase,
142
- mixids: GeneralizedIDs = None,
143
- predicate: Callable[[Any], bool] = None) -> dict[float, list[int]]:
156
+ def get_mixids_from_snr(
157
+ mixdb: MixtureDatabase,
158
+ mixids: GeneralizedIDs = "*",
159
+ predicate: Callable[[Any], bool] | None = None,
160
+ ) -> dict[float, list[int]]:
144
161
  """
145
162
  Generate mixids based on an SNR predicate
146
163
  Return a dictionary where:
@@ -155,46 +172,70 @@ def get_mixids_from_snr(mixdb: MixtureDatabase,
155
172
  snrs = [float(snr) for snr in mixdb.all_snrs if not snr.is_random]
156
173
 
157
174
  if predicate is None:
175
+
158
176
  def predicate(_: Any) -> bool:
159
177
  return True
160
178
 
161
179
  # Get only the SNRs of interest (filter on predicate)
162
180
  snrs = [snr for snr in snrs if predicate(snr)]
163
181
 
164
- result = {}
182
+ result: dict[float, list[int]] = {}
165
183
  for snr in snrs:
166
184
  # Get a list of mixids for each SNR
167
- result[snr] = sorted(
168
- [i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
185
+ result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
169
186
 
170
187
  return result
171
188
 
172
189
 
173
- def get_mixids_from_truth_index(mixdb: MixtureDatabase,
174
- mixids: GeneralizedIDs = None,
175
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
190
+ def get_mixids_from_class_indices(
191
+ mixdb: MixtureDatabase,
192
+ mixids: GeneralizedIDs = "*",
193
+ predicate: Callable[[Any], bool] | None = None,
194
+ ) -> dict[int, list[int]]:
176
195
  """
177
- Generate mixids based on a truth index predicate
196
+ Generate mixids based on a class index predicate
178
197
  Return a dictionary where:
179
- - keys are the truth indices
180
- - values are lists of the mixids that match the truth index
198
+ - keys are the class indices
199
+ - values are lists of the mixids that match the class index
181
200
  """
182
- return get_mixids_from_truth_settings_field_predicate(mixdb=mixdb,
183
- mixids=mixids,
184
- field='index',
185
- predicate=predicate)
201
+ mixid_out = mixdb.mixids_to_list(mixids)
202
+
203
+ if predicate is None:
204
+
205
+ def predicate(_: Any) -> bool:
206
+ return True
207
+
208
+ criteria_set = set()
209
+ for m_id in mixid_out:
210
+ class_indices = mixdb.mixture_class_indices(m_id)
211
+ for class_index in class_indices:
212
+ if predicate(class_index):
213
+ criteria_set.add(class_index)
214
+ criteria = sorted(criteria_set)
215
+
216
+ result: dict[int, list[int]] = {}
217
+ for criterion in criteria:
218
+ result[criterion] = []
219
+ for m_id in mixid_out:
220
+ class_indices = mixdb.mixture_class_indices(m_id)
221
+ for class_index in class_indices:
222
+ if class_index == criterion:
223
+ result[criterion].append(m_id)
224
+
225
+ return result
186
226
 
187
227
 
188
- def get_mixids_from_truth_function(mixdb: MixtureDatabase,
189
- mixids: GeneralizedIDs = None,
190
- predicate: Callable[[Any], bool] = None) -> dict[int, list[int]]:
228
+ def get_mixids_from_truth_function(
229
+ mixdb: MixtureDatabase,
230
+ mixids: GeneralizedIDs = "*",
231
+ predicate: Callable[[Any], bool] | None = None,
232
+ ) -> dict[int, list[int]]:
191
233
  """
192
234
  Generate mixids based on a truth function predicate
193
235
  Return a dictionary where:
194
236
  - keys are the truth functions
195
237
  - values are lists of the mixids that match the truth function
196
238
  """
197
- return get_mixids_from_truth_settings_field_predicate(mixdb=mixdb,
198
- mixids=mixids,
199
- field='function',
200
- predicate=predicate)
239
+ return get_mixids_from_truth_configs_field_predicate(
240
+ mixdb=mixdb, mixids=mixids, field="function", predicate=predicate
241
+ )
@@ -1,3 +1,5 @@
1
+ # ruff: noqa: F401
2
+
1
3
  from .textgrid import annotate_textgrid
2
4
  from .textgrid import create_textgrid
3
5
  from .types import TimeAlignedType
@@ -1,7 +1,6 @@
1
1
  import os
2
2
  import string
3
3
  from pathlib import Path
4
- from typing import Optional
5
4
 
6
5
  from .types import TimeAlignedType
7
6
 
@@ -9,54 +8,54 @@ from .types import TimeAlignedType
9
8
  def _get_duration(name: str) -> float:
10
9
  import soundfile
11
10
 
12
- from sonusai import SonusAIError
13
-
14
11
  try:
15
12
  return soundfile.info(name).duration
16
13
  except Exception as e:
17
- raise SonusAIError(f'Error reading {name}: {e}')
14
+ raise OSError(f"Error reading {name}: {e}") from e
18
15
 
19
16
 
20
- def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
17
+ def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
21
18
  """Load time-aligned text data given a L2-ARCTIC audio file.
22
19
 
23
20
  :param audio: Path to the L2-ARCTIC audio file.
24
21
  :return: A TimeAlignedType object.
25
22
  """
26
- file = Path(audio).parent.parent / 'transcript' / (Path(audio).stem + '.txt')
23
+ file = Path(audio).parent.parent / "transcript" / (Path(audio).stem + ".txt")
27
24
  if not os.path.exists(file):
28
25
  return None
29
26
 
30
- with open(file, mode='r', encoding='utf-8') as f:
27
+ with open(file, encoding="utf-8") as f:
31
28
  line = f.read()
32
29
 
33
- return TimeAlignedType(0,
34
- _get_duration(str(audio)),
35
- line.strip().lower().translate(str.maketrans('', '', string.punctuation)))
30
+ return TimeAlignedType(
31
+ 0,
32
+ _get_duration(str(audio)),
33
+ line.strip().lower().translate(str.maketrans("", "", string.punctuation)),
34
+ )
36
35
 
37
36
 
38
- def load_words(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
37
+ def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
39
38
  """Load time-aligned word data given a L2-ARCTIC audio file.
40
39
 
41
40
  :param audio: Path to the L2-ARCTIC audio file.
42
41
  :return: A list of TimeAlignedType objects.
43
42
  """
44
- return _load_ta(audio, 'words')
43
+ return _load_ta(audio, "words")
45
44
 
46
45
 
47
- def load_phonemes(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
46
+ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
48
47
  """Load time-aligned phonemes data given a L2-ARCTIC audio file.
49
48
 
50
49
  :param audio: Path to the L2-ARCTIC audio file.
51
50
  :return: A list of TimeAlignedType objects.
52
51
  """
53
- return _load_ta(audio, 'phones')
52
+ return _load_ta(audio, "phones")
54
53
 
55
54
 
56
- def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlignedType]]:
55
+ def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
57
56
  from praatio import textgrid
58
57
 
59
- file = Path(audio).parent.parent / 'textgrid' / (Path(audio).stem + '.TextGrid')
58
+ file = Path(audio).parent.parent / "textgrid" / (Path(audio).stem + ".TextGrid")
60
59
  if not os.path.exists(file):
61
60
  return None
62
61
 
@@ -71,7 +70,9 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlig
71
70
  return entries
72
71
 
73
72
 
74
- def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[TimeAlignedType]]]:
73
+ def load_annotations(
74
+ audio: str | os.PathLike[str],
75
+ ) -> dict[str, list[TimeAlignedType]] | None:
75
76
  """Load time-aligned annotation data given a L2-ARCTIC audio file.
76
77
 
77
78
  :param audio: Path to the L2-ARCTIC audio file.
@@ -79,7 +80,7 @@ def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[T
79
80
  """
80
81
  from praatio import textgrid
81
82
 
82
- file = Path(audio).parent.parent / 'annotation' / (Path(audio).stem + '.TextGrid')
83
+ file = Path(audio).parent.parent / "annotation" / (Path(audio).stem + ".TextGrid")
83
84
  if not os.path.exists(file):
84
85
  return None
85
86
 
@@ -96,21 +97,21 @@ def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[T
96
97
 
97
98
  def load_speakers(input_dir: Path) -> dict:
98
99
  speakers = {}
99
- with open(input_dir / 'readme-download.txt') as file:
100
+ with open(input_dir / "readme-download.txt") as file:
100
101
  processing = False
101
102
  for line in file:
102
- if not processing and line.startswith('|---|'):
103
+ if not processing and line.startswith("|---|"):
103
104
  processing = True
104
105
  continue
105
106
 
106
107
  if processing:
107
- if line.startswith('|**Total**|'):
108
+ if line.startswith("|**Total**|"):
108
109
  break
109
110
  else:
110
- fields = line.strip().split('|')
111
+ fields = line.strip().split("|")
111
112
  speaker_id = fields[1]
112
113
  gender = fields[2]
113
114
  dialect = fields[3]
114
- speakers[speaker_id] = {'gender': gender, 'dialect': dialect}
115
+ speakers[speaker_id] = {"gender": gender, "dialect": dialect}
115
116
 
116
117
  return speakers