sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py
CHANGED
@@ -41,11 +41,12 @@ TBD not sure below make sense, need to continue ??
|
|
41
41
|
3. Classification
|
42
42
|
|
43
43
|
Outputs the following to opredict-<TIMESTAMP> directory:
|
44
|
-
<id
|
45
|
-
|
44
|
+
<id>
|
45
|
+
predict.pkl
|
46
46
|
onnx_predict.log
|
47
47
|
|
48
48
|
"""
|
49
|
+
|
49
50
|
import signal
|
50
51
|
|
51
52
|
|
@@ -54,7 +55,7 @@ def signal_handler(_sig, _frame):
|
|
54
55
|
|
55
56
|
from sonusai import logger
|
56
57
|
|
57
|
-
logger.info(
|
58
|
+
logger.info("Canceled due to keyboard interrupt")
|
58
59
|
sys.exit(1)
|
59
60
|
|
60
61
|
|
@@ -69,12 +70,12 @@ def main() -> None:
|
|
69
70
|
|
70
71
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
71
72
|
|
72
|
-
verbose = args[
|
73
|
-
wav = args[
|
74
|
-
mixids = args[
|
75
|
-
include = args[
|
76
|
-
model_path = args[
|
77
|
-
data_paths = args[
|
73
|
+
verbose = args["--verbose"]
|
74
|
+
wav = args["--write-wav"]
|
75
|
+
mixids = args["--mixid"]
|
76
|
+
include = args["--include"]
|
77
|
+
model_path = args["MODEL"]
|
78
|
+
data_paths = args["DATA"]
|
78
79
|
|
79
80
|
from os import makedirs
|
80
81
|
from os.path import abspath
|
@@ -103,8 +104,8 @@ def main() -> None:
|
|
103
104
|
from sonusai.utils import write_audio
|
104
105
|
|
105
106
|
mixdb_path = None
|
106
|
-
mixdb = None
|
107
|
-
p_mixids =
|
107
|
+
mixdb: MixtureDatabase | None = None
|
108
|
+
p_mixids: list[int] = []
|
108
109
|
entries: list[PathInfo] = []
|
109
110
|
|
110
111
|
if len(data_paths) == 1 and isdir(data_paths[0]):
|
@@ -113,96 +114,98 @@ def main() -> None:
|
|
113
114
|
mixdb_path = data_paths[0]
|
114
115
|
else:
|
115
116
|
# search all data paths for .wav, .flac (or whatever is specified in include)
|
116
|
-
in_basename =
|
117
|
+
in_basename = ""
|
117
118
|
|
118
|
-
output_dir = create_ts_name(
|
119
|
+
output_dir = create_ts_name("opredict-" + in_basename)
|
119
120
|
makedirs(output_dir, exist_ok=True)
|
120
121
|
|
121
122
|
# Setup logging file
|
122
|
-
create_file_handler(join(output_dir,
|
123
|
+
create_file_handler(join(output_dir, "onnx-predict.log"))
|
123
124
|
update_console_handler(verbose)
|
124
|
-
initial_log_messages(
|
125
|
+
initial_log_messages("onnx_predict")
|
125
126
|
|
126
127
|
providers = ort.get_available_providers()
|
127
|
-
logger.info(f
|
128
|
+
logger.info(f"Loaded ONNX Runtime, available providers: {providers}.")
|
128
129
|
|
129
130
|
session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
|
130
131
|
if hparams is None:
|
131
|
-
logger.error(
|
132
|
+
logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
|
132
133
|
raise SystemExit(1)
|
133
134
|
if len(sess_inputs) != 1:
|
134
|
-
logger.error(f
|
135
|
+
logger.error(f"Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.")
|
135
136
|
|
136
137
|
in0name = sess_inputs[0].name
|
137
138
|
in0type = sess_inputs[0].type
|
138
139
|
out_names = [n.name for n in session.get_outputs()]
|
139
140
|
|
140
|
-
logger.info(f
|
141
|
+
logger.info(f"Read and compiled ONNX model from {model_path}.")
|
141
142
|
|
142
143
|
if mixdb_path is not None:
|
143
144
|
# Assume it's a single path to SonusAI mixdb subdir
|
144
|
-
logger.debug(f
|
145
|
+
logger.debug(f"Attempting to load mixture database from {mixdb_path}")
|
145
146
|
mixdb = MixtureDatabase(mixdb_path)
|
146
|
-
logger.info(f
|
147
|
+
logger.info(f"SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes")
|
147
148
|
p_mixids = mixdb.mixids_to_list(mixids)
|
148
149
|
if len(p_mixids) != mixdb.num_mixtures:
|
149
|
-
logger.info(f
|
150
|
+
logger.info(f"Processing a subset of {p_mixids} from available mixtures.")
|
150
151
|
else:
|
151
152
|
for p in data_paths:
|
152
|
-
location = join(realpath(abspath(p)),
|
153
|
-
logger.debug(f
|
153
|
+
location = join(realpath(abspath(p)), "**", include)
|
154
|
+
logger.debug(f"Processing {location}")
|
154
155
|
for file in braced_iglob(pathname=location, recursive=True):
|
155
156
|
name = file
|
156
157
|
entries.append(PathInfo(abs_path=file, audio_filepath=name))
|
157
|
-
logger.info(f
|
158
|
+
logger.info(f"{len(data_paths)} data paths specified, found {len(entries)} audio files.")
|
158
159
|
|
159
|
-
if in0type.find(
|
160
|
+
if in0type.find("float16") != -1:
|
160
161
|
model_is_fp16 = True
|
161
|
-
logger.info(
|
162
|
+
logger.info("Detected input of float16, converting all feature inputs to that type.")
|
162
163
|
else:
|
163
164
|
model_is_fp16 = False
|
164
165
|
|
165
|
-
if
|
166
|
+
if mixdb is not None and hparams["batch_size"] == 1:
|
166
167
|
# mixdb input
|
167
168
|
# Assume (of course) that mixdb feature, etc. is what model expects
|
168
|
-
if hparams[
|
169
|
-
logger.warning(
|
169
|
+
if hparams["feature"] != mixdb.feature:
|
170
|
+
logger.warning("Mixture feature does not match model feature, this inference run may fail.")
|
170
171
|
# no choice, can't use hparams.feature since it's different from the mixdb
|
171
172
|
feature_mode = mixdb.feature
|
172
173
|
|
173
174
|
for mixid in p_mixids:
|
174
175
|
# frames x stride x feature_params
|
175
176
|
feature, _ = mixdb.mixture_ft(mixid)
|
176
|
-
if hparams[
|
177
|
+
if hparams["timesteps"] == 0:
|
177
178
|
# no timestep dimension, reshape will handle
|
178
179
|
timesteps = 0
|
179
180
|
else:
|
180
181
|
# fit frames into timestep dimension (TSE mode)
|
181
182
|
timesteps = feature.shape[0]
|
182
183
|
|
183
|
-
feature, _ = reshape_inputs(
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
184
|
+
feature, _ = reshape_inputs(
|
185
|
+
feature=feature,
|
186
|
+
batch_size=1,
|
187
|
+
timesteps=timesteps,
|
188
|
+
flatten=hparams["flatten"],
|
189
|
+
add1ch=hparams["add1ch"],
|
190
|
+
)
|
188
191
|
if model_is_fp16:
|
189
|
-
feature = np.float16(feature) # type: ignore
|
192
|
+
feature = np.float16(feature) # type: ignore[assignment]
|
190
193
|
# run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
|
191
194
|
predict = session.run(out_names, {in0name: feature})[0]
|
192
195
|
# predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
|
193
196
|
output_fname = join(output_dir, mixdb.mixtures[mixid].name)
|
194
|
-
with h5py.File(output_fname,
|
195
|
-
if
|
196
|
-
del f[
|
197
|
-
f.create_dataset(
|
197
|
+
with h5py.File(output_fname, "a") as f:
|
198
|
+
if "predict" in f:
|
199
|
+
del f["predict"]
|
200
|
+
f.create_dataset("predict", data=predict)
|
198
201
|
if wav:
|
199
202
|
# note only makes sense if model is predicting audio, i.e., timestep dimension exists
|
200
203
|
# predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
|
201
204
|
predict = np.transpose(predict, [1, 0, 2])
|
202
205
|
predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
|
203
|
-
owav_name = splitext(output_fname)[0] +
|
206
|
+
owav_name = splitext(output_fname)[0] + "_predict.wav"
|
204
207
|
write_audio(owav_name, predict_audio)
|
205
208
|
|
206
209
|
|
207
|
-
if __name__ ==
|
210
|
+
if __name__ == "__main__":
|
208
211
|
main()
|
sonusai/queries/__init__.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
# SonusAI query utilities
|
2
|
+
# ruff: noqa: F401
|
3
|
+
|
2
4
|
from .queries import get_mixids_from_noise
|
3
5
|
from .queries import get_mixids_from_snr
|
4
6
|
from .queries import get_mixids_from_target
|
5
7
|
from .queries import get_mixids_from_truth_function
|
6
|
-
from .queries import
|
8
|
+
from .queries import get_mixids_from_class_indices
|
sonusai/queries/queries.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
|
+
from collections.abc import Callable
|
1
2
|
from typing import Any
|
2
|
-
from typing import Callable
|
3
3
|
|
4
4
|
from sonusai.mixture.datatypes import GeneralizedIDs
|
5
5
|
from sonusai.mixture.mixdb import MixtureDatabase
|
6
6
|
|
7
7
|
|
8
|
-
def get_mixids_from_mixture_field_predicate(
|
9
|
-
|
10
|
-
|
11
|
-
|
8
|
+
def get_mixids_from_mixture_field_predicate(
|
9
|
+
mixdb: MixtureDatabase,
|
10
|
+
field: str,
|
11
|
+
mixids: GeneralizedIDs = "*",
|
12
|
+
predicate: Callable[[Any], bool] | None = None,
|
13
|
+
) -> dict[int, list[int]]:
|
12
14
|
"""
|
13
15
|
Generate mixture IDs based on mixture field and predicate
|
14
16
|
Return a dictionary where:
|
@@ -18,6 +20,7 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
|
|
18
20
|
mixid_out = mixdb.mixids_to_list(mixids)
|
19
21
|
|
20
22
|
if predicate is None:
|
23
|
+
|
21
24
|
def predicate(_: Any) -> bool:
|
22
25
|
return True
|
23
26
|
|
@@ -30,7 +33,7 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
|
|
30
33
|
criteria_set.add(v)
|
31
34
|
elif predicate(value):
|
32
35
|
criteria_set.add(value)
|
33
|
-
criteria = sorted(
|
36
|
+
criteria = sorted(criteria_set)
|
34
37
|
|
35
38
|
result: dict[int, list[int]] = {}
|
36
39
|
for criterion in criteria:
|
@@ -47,22 +50,27 @@ def get_mixids_from_mixture_field_predicate(mixdb: MixtureDatabase,
|
|
47
50
|
return result
|
48
51
|
|
49
52
|
|
50
|
-
def
|
51
|
-
|
52
|
-
|
53
|
-
|
53
|
+
def get_mixids_from_truth_configs_field_predicate(
|
54
|
+
mixdb: MixtureDatabase,
|
55
|
+
field: str,
|
56
|
+
mixids: GeneralizedIDs = "*",
|
57
|
+
predicate: Callable[[Any], bool] | None = None,
|
58
|
+
) -> dict[int, list[int]]:
|
54
59
|
"""
|
55
|
-
Generate mixture IDs based on target
|
60
|
+
Generate mixture IDs based on target truth_configs field and predicate
|
56
61
|
Return a dictionary where:
|
57
62
|
- keys are the matching field values
|
58
63
|
- values are lists of the mixids that match the criteria
|
59
64
|
"""
|
65
|
+
from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
|
66
|
+
|
60
67
|
mixid_out = mixdb.mixids_to_list(mixids)
|
61
68
|
|
62
69
|
# Get all field values
|
63
|
-
values =
|
70
|
+
values = get_all_truth_configs_values_from_field(mixdb, field)
|
64
71
|
|
65
72
|
if predicate is None:
|
73
|
+
|
66
74
|
def predicate(_: Any) -> bool:
|
67
75
|
return True
|
68
76
|
|
@@ -75,10 +83,14 @@ def get_mixids_from_truth_settings_field_predicate(mixdb: MixtureDatabase,
|
|
75
83
|
indices = []
|
76
84
|
for t_id in mixdb.target_file_ids:
|
77
85
|
target = mixdb.target_file(t_id)
|
78
|
-
for
|
79
|
-
if
|
80
|
-
|
81
|
-
|
86
|
+
for truth_config in target.truth_configs.values():
|
87
|
+
if field in REQUIRED_TRUTH_CONFIGS:
|
88
|
+
if value in getattr(truth_config, field):
|
89
|
+
indices.append(t_id)
|
90
|
+
else:
|
91
|
+
if value in getattr(truth_config.config, field):
|
92
|
+
indices.append(t_id)
|
93
|
+
indices = sorted(set(indices))
|
82
94
|
|
83
95
|
mixids = []
|
84
96
|
for index in indices:
|
@@ -86,61 +98,66 @@ def get_mixids_from_truth_settings_field_predicate(mixdb: MixtureDatabase,
|
|
86
98
|
if index in [target.file_id for target in mixdb.mixture(m_id).targets]:
|
87
99
|
mixids.append(m_id)
|
88
100
|
|
89
|
-
mixids = sorted(
|
101
|
+
mixids = sorted(set(mixids))
|
90
102
|
if mixids:
|
91
103
|
result[value] = mixids
|
92
104
|
|
93
105
|
return result
|
94
106
|
|
95
107
|
|
96
|
-
def
|
108
|
+
def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str) -> list:
|
97
109
|
"""
|
98
|
-
Generate a list of all values corresponding to the given field in
|
110
|
+
Generate a list of all values corresponding to the given field in truth_configs
|
99
111
|
"""
|
112
|
+
from sonusai.mixture import REQUIRED_TRUTH_CONFIGS
|
113
|
+
|
100
114
|
result = []
|
101
115
|
for target in mixdb.target_files:
|
102
|
-
for
|
103
|
-
|
116
|
+
for truth_config in target.truth_configs.values():
|
117
|
+
if field in REQUIRED_TRUTH_CONFIGS:
|
118
|
+
value = getattr(truth_config, field)
|
119
|
+
else:
|
120
|
+
value = getattr(truth_config.config, field, None)
|
104
121
|
if isinstance(value, str):
|
105
122
|
value = [value]
|
106
123
|
result.extend(value)
|
107
124
|
|
108
|
-
return sorted(
|
125
|
+
return sorted(set(result))
|
109
126
|
|
110
127
|
|
111
|
-
def get_mixids_from_noise(
|
112
|
-
|
113
|
-
|
128
|
+
def get_mixids_from_noise(
|
129
|
+
mixdb: MixtureDatabase,
|
130
|
+
mixids: GeneralizedIDs = "*",
|
131
|
+
predicate: Callable[[Any], bool] | None = None,
|
132
|
+
) -> dict[int, list[int]]:
|
114
133
|
"""
|
115
134
|
Generate mixids based on noise index predicate
|
116
135
|
Return a dictionary where:
|
117
136
|
- keys are the noise indices
|
118
137
|
- values are lists of the mixids that match the noise index
|
119
138
|
"""
|
120
|
-
return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
|
121
|
-
mixids=mixids,
|
122
|
-
field='noise_id',
|
123
|
-
predicate=predicate)
|
139
|
+
return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="noise_id", predicate=predicate)
|
124
140
|
|
125
141
|
|
126
|
-
def get_mixids_from_target(
|
127
|
-
|
128
|
-
|
142
|
+
def get_mixids_from_target(
|
143
|
+
mixdb: MixtureDatabase,
|
144
|
+
mixids: GeneralizedIDs = "*",
|
145
|
+
predicate: Callable[[Any], bool] | None = None,
|
146
|
+
) -> dict[int, list[int]]:
|
129
147
|
"""
|
130
148
|
Generate mixids based on a target index predicate
|
131
149
|
Return a dictionary where:
|
132
150
|
- keys are the target indices
|
133
151
|
- values are lists of the mixids that match the target index
|
134
152
|
"""
|
135
|
-
return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
|
136
|
-
mixids=mixids,
|
137
|
-
field='target_ids',
|
138
|
-
predicate=predicate)
|
153
|
+
return get_mixids_from_mixture_field_predicate(mixdb=mixdb, mixids=mixids, field="target_ids", predicate=predicate)
|
139
154
|
|
140
155
|
|
141
|
-
def get_mixids_from_snr(
|
142
|
-
|
143
|
-
|
156
|
+
def get_mixids_from_snr(
|
157
|
+
mixdb: MixtureDatabase,
|
158
|
+
mixids: GeneralizedIDs = "*",
|
159
|
+
predicate: Callable[[Any], bool] | None = None,
|
160
|
+
) -> dict[float, list[int]]:
|
144
161
|
"""
|
145
162
|
Generate mixids based on an SNR predicate
|
146
163
|
Return a dictionary where:
|
@@ -155,46 +172,70 @@ def get_mixids_from_snr(mixdb: MixtureDatabase,
|
|
155
172
|
snrs = [float(snr) for snr in mixdb.all_snrs if not snr.is_random]
|
156
173
|
|
157
174
|
if predicate is None:
|
175
|
+
|
158
176
|
def predicate(_: Any) -> bool:
|
159
177
|
return True
|
160
178
|
|
161
179
|
# Get only the SNRs of interest (filter on predicate)
|
162
180
|
snrs = [snr for snr in snrs if predicate(snr)]
|
163
181
|
|
164
|
-
result = {}
|
182
|
+
result: dict[float, list[int]] = {}
|
165
183
|
for snr in snrs:
|
166
184
|
# Get a list of mixids for each SNR
|
167
|
-
result[snr] = sorted(
|
168
|
-
[i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
|
185
|
+
result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
|
169
186
|
|
170
187
|
return result
|
171
188
|
|
172
189
|
|
173
|
-
def
|
174
|
-
|
175
|
-
|
190
|
+
def get_mixids_from_class_indices(
|
191
|
+
mixdb: MixtureDatabase,
|
192
|
+
mixids: GeneralizedIDs = "*",
|
193
|
+
predicate: Callable[[Any], bool] | None = None,
|
194
|
+
) -> dict[int, list[int]]:
|
176
195
|
"""
|
177
|
-
Generate mixids based on a
|
196
|
+
Generate mixids based on a class index predicate
|
178
197
|
Return a dictionary where:
|
179
|
-
- keys are the
|
180
|
-
- values are lists of the mixids that match the
|
198
|
+
- keys are the class indices
|
199
|
+
- values are lists of the mixids that match the class index
|
181
200
|
"""
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
201
|
+
mixid_out = mixdb.mixids_to_list(mixids)
|
202
|
+
|
203
|
+
if predicate is None:
|
204
|
+
|
205
|
+
def predicate(_: Any) -> bool:
|
206
|
+
return True
|
207
|
+
|
208
|
+
criteria_set = set()
|
209
|
+
for m_id in mixid_out:
|
210
|
+
class_indices = mixdb.mixture_class_indices(m_id)
|
211
|
+
for class_index in class_indices:
|
212
|
+
if predicate(class_index):
|
213
|
+
criteria_set.add(class_index)
|
214
|
+
criteria = sorted(criteria_set)
|
215
|
+
|
216
|
+
result: dict[int, list[int]] = {}
|
217
|
+
for criterion in criteria:
|
218
|
+
result[criterion] = []
|
219
|
+
for m_id in mixid_out:
|
220
|
+
class_indices = mixdb.mixture_class_indices(m_id)
|
221
|
+
for class_index in class_indices:
|
222
|
+
if class_index == criterion:
|
223
|
+
result[criterion].append(m_id)
|
224
|
+
|
225
|
+
return result
|
186
226
|
|
187
227
|
|
188
|
-
def get_mixids_from_truth_function(
|
189
|
-
|
190
|
-
|
228
|
+
def get_mixids_from_truth_function(
|
229
|
+
mixdb: MixtureDatabase,
|
230
|
+
mixids: GeneralizedIDs = "*",
|
231
|
+
predicate: Callable[[Any], bool] | None = None,
|
232
|
+
) -> dict[int, list[int]]:
|
191
233
|
"""
|
192
234
|
Generate mixids based on a truth function predicate
|
193
235
|
Return a dictionary where:
|
194
236
|
- keys are the truth functions
|
195
237
|
- values are lists of the mixids that match the truth function
|
196
238
|
"""
|
197
|
-
return
|
198
|
-
|
199
|
-
|
200
|
-
predicate=predicate)
|
239
|
+
return get_mixids_from_truth_configs_field_predicate(
|
240
|
+
mixdb=mixdb, mixids=mixids, field="function", predicate=predicate
|
241
|
+
)
|
sonusai/speech/__init__.py
CHANGED
sonusai/speech/l2arctic.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import string
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Optional
|
5
4
|
|
6
5
|
from .types import TimeAlignedType
|
7
6
|
|
@@ -9,54 +8,54 @@ from .types import TimeAlignedType
|
|
9
8
|
def _get_duration(name: str) -> float:
|
10
9
|
import soundfile
|
11
10
|
|
12
|
-
from sonusai import SonusAIError
|
13
|
-
|
14
11
|
try:
|
15
12
|
return soundfile.info(name).duration
|
16
13
|
except Exception as e:
|
17
|
-
raise
|
14
|
+
raise OSError(f"Error reading {name}: {e}") from e
|
18
15
|
|
19
16
|
|
20
|
-
def load_text(audio: str | os.PathLike[str]) ->
|
17
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
|
21
18
|
"""Load time-aligned text data given a L2-ARCTIC audio file.
|
22
19
|
|
23
20
|
:param audio: Path to the L2-ARCTIC audio file.
|
24
21
|
:return: A TimeAlignedType object.
|
25
22
|
"""
|
26
|
-
file = Path(audio).parent.parent /
|
23
|
+
file = Path(audio).parent.parent / "transcript" / (Path(audio).stem + ".txt")
|
27
24
|
if not os.path.exists(file):
|
28
25
|
return None
|
29
26
|
|
30
|
-
with open(file,
|
27
|
+
with open(file, encoding="utf-8") as f:
|
31
28
|
line = f.read()
|
32
29
|
|
33
|
-
return TimeAlignedType(
|
34
|
-
|
35
|
-
|
30
|
+
return TimeAlignedType(
|
31
|
+
0,
|
32
|
+
_get_duration(str(audio)),
|
33
|
+
line.strip().lower().translate(str.maketrans("", "", string.punctuation)),
|
34
|
+
)
|
36
35
|
|
37
36
|
|
38
|
-
def load_words(audio: str | os.PathLike[str]) ->
|
37
|
+
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
|
39
38
|
"""Load time-aligned word data given a L2-ARCTIC audio file.
|
40
39
|
|
41
40
|
:param audio: Path to the L2-ARCTIC audio file.
|
42
41
|
:return: A list of TimeAlignedType objects.
|
43
42
|
"""
|
44
|
-
return _load_ta(audio,
|
43
|
+
return _load_ta(audio, "words")
|
45
44
|
|
46
45
|
|
47
|
-
def load_phonemes(audio: str | os.PathLike[str]) ->
|
46
|
+
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
|
48
47
|
"""Load time-aligned phonemes data given a L2-ARCTIC audio file.
|
49
48
|
|
50
49
|
:param audio: Path to the L2-ARCTIC audio file.
|
51
50
|
:return: A list of TimeAlignedType objects.
|
52
51
|
"""
|
53
|
-
return _load_ta(audio,
|
52
|
+
return _load_ta(audio, "phones")
|
54
53
|
|
55
54
|
|
56
|
-
def _load_ta(audio: str | os.PathLike[str], tier: str) ->
|
55
|
+
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
|
57
56
|
from praatio import textgrid
|
58
57
|
|
59
|
-
file = Path(audio).parent.parent /
|
58
|
+
file = Path(audio).parent.parent / "textgrid" / (Path(audio).stem + ".TextGrid")
|
60
59
|
if not os.path.exists(file):
|
61
60
|
return None
|
62
61
|
|
@@ -71,7 +70,9 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlig
|
|
71
70
|
return entries
|
72
71
|
|
73
72
|
|
74
|
-
def load_annotations(
|
73
|
+
def load_annotations(
|
74
|
+
audio: str | os.PathLike[str],
|
75
|
+
) -> dict[str, list[TimeAlignedType]] | None:
|
75
76
|
"""Load time-aligned annotation data given a L2-ARCTIC audio file.
|
76
77
|
|
77
78
|
:param audio: Path to the L2-ARCTIC audio file.
|
@@ -79,7 +80,7 @@ def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[T
|
|
79
80
|
"""
|
80
81
|
from praatio import textgrid
|
81
82
|
|
82
|
-
file = Path(audio).parent.parent /
|
83
|
+
file = Path(audio).parent.parent / "annotation" / (Path(audio).stem + ".TextGrid")
|
83
84
|
if not os.path.exists(file):
|
84
85
|
return None
|
85
86
|
|
@@ -96,21 +97,21 @@ def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[T
|
|
96
97
|
|
97
98
|
def load_speakers(input_dir: Path) -> dict:
|
98
99
|
speakers = {}
|
99
|
-
with open(input_dir /
|
100
|
+
with open(input_dir / "readme-download.txt") as file:
|
100
101
|
processing = False
|
101
102
|
for line in file:
|
102
|
-
if not processing and line.startswith(
|
103
|
+
if not processing and line.startswith("|---|"):
|
103
104
|
processing = True
|
104
105
|
continue
|
105
106
|
|
106
107
|
if processing:
|
107
|
-
if line.startswith(
|
108
|
+
if line.startswith("|**Total**|"):
|
108
109
|
break
|
109
110
|
else:
|
110
|
-
fields = line.strip().split(
|
111
|
+
fields = line.strip().split("|")
|
111
112
|
speaker_id = fields[1]
|
112
113
|
gender = fields[2]
|
113
114
|
dialect = fields[3]
|
114
|
-
speakers[speaker_id] = {
|
115
|
+
speakers[speaker_id] = {"gender": gender, "dialect": dialect}
|
115
116
|
|
116
117
|
return speakers
|