sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/mixture/helpers.py
CHANGED
@@ -1,17 +1,13 @@
-from
-from
+from pyaaware import ForwardTransform
+from pyaaware import InverseTransform
 
-from praatio.utilities.constants import Interval
-
-from sonusai import ForwardTransform
-from sonusai import InverseTransform
-from sonusai.mixture import EnergyT
 from sonusai.mixture.datatypes import AudioF
-from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosT
+from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import Augmentation
 from sonusai.mixture.datatypes import AugmentationRules
 from sonusai.mixture.datatypes import Augmentations
+from sonusai.mixture.datatypes import EnergyT
 from sonusai.mixture.datatypes import Feature
 from sonusai.mixture.datatypes import FeatureGeneratorConfig
 from sonusai.mixture.datatypes import FeatureGeneratorInfo

@@ -25,37 +21,33 @@ from sonusai.mixture.datatypes import Target
 from sonusai.mixture.datatypes import TargetFiles
 from sonusai.mixture.datatypes import Targets
 from sonusai.mixture.datatypes import TransformConfig
-from sonusai.mixture.datatypes import
+from sonusai.mixture.datatypes import TruthDict
 from sonusai.mixture.db_datatypes import MixtureRecord
 from sonusai.mixture.db_datatypes import TargetRecord
 from sonusai.mixture.mixdb import MixtureDatabase
 
 
-def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs =
+def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = "*") -> list[int]:
     """Resolve generalized IDs to a list of integers
 
     :param num_ids: Total number of indices
     :param ids: Generalized IDs
     :return: List of ID integers
     """
-    from sonusai import SonusAIError
-
     all_ids = list(range(num_ids))
 
-    if ids is None:
-        return all_ids
-
     if isinstance(ids, str):
-        if ids ==
+        if ids == "*":
             return all_ids
 
         try:
-            result = eval(f
-            if
-
-
-
-
+            result = eval(f"{all_ids}[{ids}]")  # noqa: S307
+            if isinstance(result, list):
+                return result
+            else:
+                return [result]
+        except NameError as e:
+            raise ValueError(f"Empty ids {ids}: {e}") from e
 
     if isinstance(ids, range):
         result = list(ids)

@@ -65,15 +57,17 @@ def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = None) -> list[int]:
         result = ids
 
     if not all(isinstance(x, int) and 0 <= x < num_ids for x in result):
-        raise
+        raise ValueError(f"Invalid entries in ids of {ids}")
 
     if not result:
-        raise
+        raise ValueError(f"Empty ids {ids}")
 
     return result
 
 
-def get_feature_generator_info(
+def get_feature_generator_info(
+    fg_config: FeatureGeneratorConfig,
+) -> FeatureGeneratorInfo:
     from dataclasses import asdict
 
     from pyaaware import FeatureGenerator

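For illustration, here is a minimal standalone sketch of the ID-resolution behavior shown in the two hunks above. The `resolve_ids` name is hypothetical; the logic simply mirrors the 0.19.6 `generic_ids_to_list` added lines.

```python
def resolve_ids(num_ids: int, ids="*") -> list[int]:
    # Hypothetical standalone mirror of the new generic_ids_to_list logic.
    all_ids = list(range(num_ids))

    if isinstance(ids, str):
        if ids == "*":
            return all_ids
        # A string is evaluated as a slice/index into the full ID list,
        # e.g. ":3" or "-1" (same eval-based approach as the code above).
        result = eval(f"{all_ids}[{ids}]")  # noqa: S307
        return result if isinstance(result, list) else [result]

    result = list(ids) if isinstance(ids, range) else ids
    if not all(isinstance(x, int) and 0 <= x < num_ids for x in result):
        raise ValueError(f"Invalid entries in ids of {ids}")
    if not result:
        raise ValueError(f"Empty ids {ids}")
    return result


print(resolve_ids(5))            # [0, 1, 2, 3, 4]
print(resolve_ids(5, ":3"))      # [0, 1, 2]
print(resolve_ids(5, "-1"))      # [4]
print(resolve_ids(5, range(2)))  # [0, 1]
print(resolve_ids(5, [1, 3]))    # [1, 3]
```

Note that the default changed from `None` to `"*"` and that invalid or empty IDs now raise `ValueError` rather than `SonusAIError`.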
@@ -88,49 +82,36 @@ def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGene
         stride=fg.stride,
         step=fg.step,
         feature_parameters=fg.feature_parameters,
-        ft_config=TransformConfig(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        ft_config=TransformConfig(
+            length=fg.ftransform_length,
+            overlap=fg.ftransform_overlap,
+            bin_start=fg.bin_start,
+            bin_end=fg.bin_end,
+            ttype=fg.ftransform_ttype,
+        ),
+        eft_config=TransformConfig(
+            length=fg.eftransform_length,
+            overlap=fg.eftransform_overlap,
+            bin_start=fg.bin_start,
+            bin_end=fg.bin_end,
+            ttype=fg.eftransform_ttype,
+        ),
+        it_config=TransformConfig(
+            length=fg.itransform_length,
+            overlap=fg.itransform_overlap,
+            bin_start=fg.bin_start,
+            bin_end=fg.bin_end,
+            ttype=fg.itransform_ttype,
+        ),
     )
 
 
-def
-
-
-    """Write mixture data to a mixture HDF5 file
-
-    :param mixdb: Mixture database
-    :param mixture: Mixture record
-    :param items: Tuple(s) of (name, data)
-    """
-    import h5py
-
-    if not isinstance(items, list):
-        items = [items]
-
-    name = mixdb.location_filename(mixture.name)
-    with h5py.File(name=name, mode='a') as f:
-        for item in items:
-            if item[0] in f:
-                del f[item[0]]
-            f.create_dataset(name=item[0], data=item[1])
+def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
+    """Get a list of all speech metadata for the given mixture"""
+    from praatio.utilities.constants import Interval
 
+    from .datatypes import SpeechMetadata
 
-def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
-    """Get a list of all speech metadata for the given mixture
-    """
     results: list[dict[str, SpeechMetadata]] = []
     for target in mixture.targets:
         data: dict[str, SpeechMetadata] = {}

@@ -144,9 +125,13 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
                 entries = []
                 for entry in item:
                     if target.augmentation.tempo is not None:
-                        entries.append(
-
-
+                        entries.append(
+                            Interval(
+                                entry.start / target.augmentation.tempo,
+                                entry.end / target.augmentation.tempo,
+                                entry.label,
+                            )
+                        )
                     else:
                         entries.append(entry)
                 data[tier] = entries

@@ -164,41 +149,32 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
     :param mixture: Mixture record
     :return: String of metadata
     """
-    metadata =
+    metadata = ""
     speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
     for mi, target in enumerate(mixture.targets):
         target_file = mixdb.target_file(target.file_id)
         target_augmentation = target.augmentation
-        metadata += f
-        metadata += f
-
-
-
-
-
-
-
-
-            metadata += f'target {mi} truth index {tsi}: {truth_settings[tsi].index}\n'
-            metadata += f'target {mi} truth function {tsi}: {truth_settings[tsi].function}\n'
-            metadata += f'target {mi} truth config {tsi}: {truth_settings[tsi].config}\n'
-        for key in speech_metadata[mi].keys():
-            metadata += f'target {mi} speech {key}: {speech_metadata[mi][key]}\n'
+        metadata += f"target {mi} name: {target_file.name}\n"
+        metadata += f"target {mi} augmentation: {target.augmentation.to_dict()}\n"
+        metadata += f"target {mi} ir: {mixdb.impulse_response_file(target_augmentation.ir)}\n"
+        metadata += f"target {mi} target_gain: {target.gain}\n"
+        metadata += f"target {mi} class indices: {target_file.class_indices}\n"
+        for key in target_file.truth_configs:
+            metadata += f"target {mi} truth '{key}' function: {target_file.truth_configs[key].function}\n"
+            metadata += f"target {mi} truth '{key}' config: {target_file.truth_configs[key].config}\n"
+        for key in speech_metadata[mi]:
+            metadata += f"target {mi} speech {key}: {speech_metadata[mi][key]}\n"
     noise = mixdb.noise_file(mixture.noise.file_id)
     noise_augmentation = mixture.noise.augmentation
-    metadata += f
-    metadata += f
-
-
-
-
-    metadata += f
-    metadata += f
-    metadata += f
-    metadata += f'random_snr: {mixture.snr.is_random}\n'
-    metadata += f'samples: {mixture.samples}\n'
-    metadata += f'target_snr_gain: {float(mixture.target_snr_gain)}\n'
-    metadata += f'noise_snr_gain: {float(mixture.noise_snr_gain)}\n'
+    metadata += f"noise name: {noise.name}\n"
+    metadata += f"noise augmentation: {noise_augmentation.to_dict()}\n"
+    metadata += f"noise ir: {mixdb.impulse_response_file(noise_augmentation.ir)}\n"
+    metadata += f"noise offset: {mixture.noise.offset}\n"
+    metadata += f"snr: {mixture.snr}\n"
+    metadata += f"random_snr: {mixture.snr.is_random}\n"
+    metadata += f"samples: {mixture.samples}\n"
+    metadata += f"target_snr_gain: {float(mixture.target_snr_gain)}\n"
+    metadata += f"noise_snr_gain: {float(mixture.noise_snr_gain)}\n"
 
     return metadata
 

@@ -209,47 +185,54 @@ def write_mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> None:
     :param mixdb: Mixture database
     :param mixture: Mixture record
     """
-    from os.path import
+    from os.path import join
 
-    name = mixdb.
-    with open(file=name, mode=
+    name = join(mixdb.location, "mixture", mixture.name, "metadata.txt")
+    with open(file=name, mode="w") as f:
         f.write(mixture_metadata(mixdb, mixture))
 
 
-def from_mixture(
-
-
-
-
-
-
-
-
-
-
-
+def from_mixture(
+    mixture: Mixture,
+) -> tuple[str, int, str, int, float, bool, float, int, int, int, float]:
+    return (
+        mixture.name,
+        mixture.noise.file_id,
+        mixture.noise.augmentation.to_json(),
+        mixture.noise.offset,
+        mixture.noise_snr_gain,
+        mixture.snr.is_random,
+        mixture.snr,
+        mixture.samples,
+        mixture.spectral_mask_id,
+        mixture.spectral_mask_seed,
+        mixture.target_snr_gain,
+    )
 
 
 def to_mixture(entry: MixtureRecord, targets: Targets) -> Mixture:
     import json
 
     from sonusai.utils import dataclass_from_dict
-
-    from .datatypes import Mixture
+
     from .datatypes import Noise
     from .datatypes import UniversalSNR
 
-    return Mixture(
-
-
-
-
-
-
-
-
-
-
+    return Mixture(
+        targets=targets,
+        name=entry.name,
+        noise=Noise(
+            file_id=entry.noise_file_id,
+            augmentation=dataclass_from_dict(Augmentation, json.loads(entry.noise_augmentation)),
+            offset=entry.noise_offset,
+        ),
+        noise_snr_gain=entry.noise_snr_gain,
+        snr=UniversalSNR(is_random=entry.random_snr, value=entry.snr),
+        samples=entry.samples,
+        spectral_mask_id=entry.spectral_mask_id,
+        spectral_mask_seed=entry.spectral_mask_seed,
+        target_snr_gain=entry.target_snr_gain,
+    )
 
 
 def from_target(target: Target) -> tuple[int, str, float]:

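The `from_mixture`/`to_mixture` pair above flattens a mixture record into a tuple of database columns and rebuilds the dataclass from a stored row. A toy sketch of that round-trip pattern is below; the `Item` dataclass and its fields are hypothetical stand-ins, not SonusAI types.

```python
import json
from dataclasses import dataclass


@dataclass
class Item:
    name: str
    params: dict


def from_item(item: Item) -> tuple[str, str]:
    # Flatten to the column order a database row would use;
    # nested structure is serialized to JSON, as in from_mixture.
    return item.name, json.dumps(item.params)


def to_item(row: tuple[str, str]) -> Item:
    # Rebuild the dataclass from a stored row, as in to_mixture.
    name, params = row
    return Item(name=name, params=json.loads(params))


row = from_item(Item(name="mix_00000", params={"tempo": 1.2}))
print(to_item(row))  # Item(name='mix_00000', params={'tempo': 1.2})
```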
@@ -260,105 +243,67 @@ def to_target(entry: TargetRecord) -> Target:
     import json
 
     from sonusai.utils import dataclass_from_dict
+
     from .datatypes import Augmentation
     from .datatypes import Target
 
-    return Target(
-
-
-
-
-def read_mixture_data(name: str, items: list[str] | str) -> Any:
-    """Read mixture data from a mixture HDF5 file
-
-    :param name: Mixture file name
-    :param items: String(s) of dataset(s) to retrieve
-    :return: Data (or tuple of data)
-    """
-    from os.path import exists
-
-    import h5py
-    import numpy as np
-
-    from sonusai import SonusAIError
-
-    def _get_dataset(file: h5py.File, d_name: str) -> Any:
-        if d_name in file:
-            data = np.array(file[d_name])
-            if data.size == 1:
-                item = data.item()
-                if isinstance(item, bytes):
-                    return item.decode('utf-8')
-                return item
-            return data
-        return None
-
-    if not isinstance(items, list):
-        items = [items]
-
-    if exists(name):
-        try:
-            with h5py.File(name, 'r') as f:
-                result = ([_get_dataset(f, item) for item in items])
-        except Exception as e:
-            raise SonusAIError(f'Error reading {name}: {e}')
-    else:
-        result = ([None for _ in items])
-
-    if len(items) == 1:
-        result = result[0]
-
-    return result
+    return Target(
+        file_id=entry.file_id,
+        augmentation=dataclass_from_dict(Augmentation, json.loads(entry.augmentation)),
+        gain=entry.gain,
+    )
 
 
-def
-
-
-
-
-
+def get_truth(
+    mixdb: MixtureDatabase,
+    mixture: Mixture,
+    targets_audio: AudiosT,
+    noise_audio: AudioT,
+    mixture_audio: AudioT,
+) -> TruthDict:
+    """Get the truth data for the given mixture record
 
     :param mixdb: Mixture database
     :param mixture: Mixture record
     :param targets_audio: List of augmented target audio data (one per target in the mixup) for the given mixture ID
     :param noise_audio: Augmented noise audio data for the given mixture ID
     :param mixture_audio: Mixture audio data for the given mixture ID
-    :return:
+    :return: truth data
     """
-
-
-    from sonusai import SonusAIError
-    from .datatypes import TruthFunctionConfig
+    from .datatypes import TruthDict
     from .truth import truth_function
 
     if not all(len(target) == mixture.samples for target in targets_audio):
-        raise
+        raise ValueError("Lengths of targets do not match length of mixture")
 
     if len(noise_audio) != mixture.samples:
-        raise
+        raise ValueError("Length of noise does not match length of mixture")
 
     # TODO: Need to understand how to do this correctly for mixup and target_mixture_f truth
-
+    if len(targets_audio) != 1:
+        raise NotImplementedError("mixup is not implemented")
+
+    truth: TruthDict = {}
     for idx in range(len(targets_audio)):
-
-
+        target_file = mixdb.target_file(mixture.targets[idx].file_id)
+        for key, value in target_file.truth_configs.items():
+            truth[key] = truth_function(
+                target_audio=targets_audio[idx],
+                noise_audio=noise_audio,
+                mixture_audio=mixture_audio,
+                config=value,
                 feature=mixdb.feature,
-                index=truth_setting.index,
-                function=truth_setting.function,
-                config=truth_setting.config,
                 num_classes=mixdb.num_classes,
-
-                target_gain=mixture.targets[idx].gain * mixture.target_snr_gain
+                class_indices=target_file.class_indices,
+                target_gain=mixture.targets[idx].gain * mixture.target_snr_gain,
             )
-        truth_t += truth_function(target_audio=targets_audio[idx],
-                                  noise_audio=noise_audio,
-                                  mixture_audio=mixture_audio,
-                                  config=config)
 
-    return
+    return truth
 
 
-def get_ft(
+def get_ft(
+    mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, truth_t: TruthDict
+) -> tuple[Feature, TruthDict]:
     """Get the feature and truth_f data for the given mixture record
 
     :param mixdb: Mixture database

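`get_truth` now returns a `TruthDict` keyed by each target file's truth-config name instead of accumulating a single truth array. A toy sketch of that per-key assembly is below; the config names, placeholder truth functions, and synthetic audio are illustrative only and are not SonusAI APIs.

```python
import numpy as np

# Synthetic stand-ins for the augmented target and noise audio of one mixture.
rng = np.random.default_rng(0)
target_audio = rng.normal(size=1024).astype(np.float32)
noise_audio = np.zeros_like(target_audio)

# Placeholder "truth functions" keyed by config name, mimicking the
# truth[key] = truth_function(...) loop in the new get_truth.
truth_configs = {
    "energy": lambda t, n: float(np.sum(t**2)),
    "active": lambda t, n: bool(np.max(np.abs(t)) > 0.1),
}

truth: dict[str, object] = {}
for key, function in truth_configs.items():
    truth[key] = function(target_audio, noise_audio)

print(truth)  # e.g. {'energy': 1034.5..., 'active': True}
```

Downstream, `get_ft` consumes the same dict and applies a per-key stride reduction, as shown in the next hunk.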
@@ -367,37 +312,19 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
     :param truth_t: truth_t for the given mixid
     :return: Tuple of (feature, truth_f) data
     """
-    from dataclasses import asdict
 
-    import numpy as np
     from pyaaware import FeatureGenerator
 
-    from .truth import
+    from .truth import truth_stride_reduction
 
     mixture_f = get_mixture_f(mixdb=mixdb, mixture=mixture, mixture_audio=mixture_audio)
 
-
-
+    fg = FeatureGenerator(mixdb.fg_config.feature_mode, mixdb.fg_config.truth_parameters)
+    feature, truth_f = fg.execute_all(mixture_f, truth_t)
+    for name in truth_f:
+        truth_f[name] = truth_stride_reduction(truth_f[name], mixdb.truth_configs[name].stride_reduction)
 
-    feature
-    truth_f = np.empty((feature_frames, mixdb.num_classes), dtype=np.complex64)
-
-    fg = FeatureGenerator(**asdict(mixdb.fg_config))
-    feature_frame = 0
-    for transform_frame in range(transform_frames):
-        indices = slice(transform_frame * mixdb.ft_config.R, (transform_frame + 1) * mixdb.ft_config.R)
-        fg.execute(mixture_f[transform_frame],
-                   truth_reduction(truth_t[indices], mixdb.truth_reduction_function))
-
-        if fg.eof():
-            feature[feature_frame] = fg.feature()
-            truth_f[feature_frame] = fg.truth()
-            feature_frame += 1
-
-    if np.isreal(truth_f).all():
-        return feature, truth_f.real
-
-    return feature, truth_f  # type: ignore
+    return feature, truth_f
 
 
 def get_segsnr(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise: AudioT) -> Segsnr:

@@ -410,7 +337,7 @@ def get_segsnr(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, n
     :return: segsnr data
     """
    segsnr_t = get_segsnr_t(mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise)
-    return segsnr_t[0::mixdb.ft_config.
+    return segsnr_t[0 :: mixdb.ft_config.overlap]
 
 
 def get_segsnr_t(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT) -> Segsnr:

@@ -424,28 +351,29 @@ def get_segsnr_t(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT,
     """
     import numpy as np
     import torch
-    from
-
-
-
-
-
-
-
-
+    from pyaaware import ForwardTransform
+
+    fft = ForwardTransform(
+        length=mixdb.ft_config.length,
+        overlap=mixdb.ft_config.overlap,
+        bin_start=mixdb.ft_config.bin_start,
+        bin_end=mixdb.ft_config.bin_end,
+        ttype=mixdb.ft_config.ttype,
+    )
 
     segsnr_t = np.empty(mixture.samples, dtype=np.float32)
 
     target_energy = fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
     noise_energy = fft.execute_all(torch.from_numpy(noise_audio))[1].numpy()
 
-    offsets = range(0, mixture.samples, mixdb.ft_config.
+    offsets = range(0, mixture.samples, mixdb.ft_config.overlap)
     if len(target_energy) != len(offsets):
-        raise
-
+        raise ValueError(
+            f"Number of frames in energy, {len(target_energy)}," f" is not number of frames in mixture, {len(offsets)}"
+        )
 
     for idx, offset in enumerate(offsets):
-        indices = slice(offset, offset + mixdb.ft_config.
+        indices = slice(offset, offset + mixdb.ft_config.overlap)
 
         if noise_energy[idx] == 0:
             snr = np.float32(np.inf)

@@ -475,8 +403,9 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: AudiosT)
     for idx, target in enumerate(targets_audio):
         ir_idx = mixture.targets[idx].augmentation.ir
         if ir_idx is not None:
-            targets_ir.append(
-
+            targets_ir.append(
+                apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
+            )
         else:
             targets_ir.append(target)
 

@@ -497,9 +426,11 @@ def get_mixture_f(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: Audio
     mixture_f = forward_transform(mixture_audio, mixdb.ft_config)
 
     if mixture.spectral_mask_id is not None:
-        mixture_f = apply_spectral_mask(
-
-
+        mixture_f = apply_spectral_mask(
+            audio_f=mixture_f,
+            spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
+            seed=mixture.spectral_mask_seed,
+        )
 
     return mixture_f
 

@@ -527,14 +458,18 @@ def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
     :param config: Transform configuration
     :return: Frequency domain data [frames, bins]
     """
-    from
-
-    audio_f, _ = get_transform_from_audio(
-
-
-
-
-
+    from pyaaware import ForwardTransform
+
+    audio_f, _ = get_transform_from_audio(
+        audio=audio,
+        transform=ForwardTransform(
+            length=config.length,
+            overlap=config.overlap,
+            bin_start=config.bin_start,
+            bin_end=config.bin_end,
+            ttype=config.ttype,
+        ),
+    )
     return audio_f
 
 

@@ -545,6 +480,7 @@ def get_audio_from_transform(data: AudioF, transform: InverseTransform) -> tuple
     :param transform: InverseTransform object
     :return: Time domain data [samples], Energy [frames]
     """
+
     import torch
 
     t, e = transform.execute_all(torch.from_numpy(data))

@@ -562,40 +498,44 @@ def inverse_transform(transform: AudioF, config: TransformConfig) -> AudioT:
     :return: Time domain data [samples]
     """
     import numpy as np
-    from
-
-    audio, _ = get_audio_from_transform(
-
-
-
-
-
-
+    from pyaaware import InverseTransform
+
+    audio, _ = get_audio_from_transform(
+        data=transform,
+        transform=InverseTransform(
+            length=config.length,
+            overlap=config.overlap,
+            bin_start=config.bin_start,
+            bin_end=config.bin_end,
+            ttype=config.ttype,
+            gain=np.float32(1),
+        ),
+    )
     return audio
 
 
 def check_audio_files_exist(mixdb: MixtureDatabase) -> None:
-    """Walk through all the noise and target audio files in a mixture database ensuring that they exist
-    """
+    """Walk through all the noise and target audio files in a mixture database ensuring that they exist"""
     from os.path import exists
 
-    from sonusai import SonusAIError
     from .tokenized_shell_vars import tokenized_expand
 
     for noise in mixdb.noise_files:
         file_name, _ = tokenized_expand(noise.name)
         if not exists(file_name):
-            raise
+            raise OSError(f"Could not find {file_name}")
 
     for target in mixdb.target_files:
         file_name, _ = tokenized_expand(target.name)
         if not exists(file_name):
-            raise
+            raise OSError(f"Could not find {file_name}")
 
 
-def augmented_target_samples(
-
-
+def augmented_target_samples(
+    target_files: TargetFiles,
+    target_augmentations: AugmentationRules,
+    feature_step_samples: int,
+) -> int:
     from itertools import product
 
     from .augmentation import estimate_augmented_length_from_length

@@ -603,10 +543,16 @@ def augmented_target_samples(target_files: TargetFiles,
     target_ids = list(range(len(target_files)))
     target_augmentation_ids = list(range(len(target_augmentations)))
     it = list(product(*[target_ids, target_augmentation_ids]))
-    return sum(
-
-
-
+    return sum(
+        [
+            estimate_augmented_length_from_length(
+                length=target_files[fi].samples,
+                tempo=target_augmentations[ai].tempo,
+                frame_length=feature_step_samples,
+            )
+            for fi, ai in it
+        ]
+    )
 
 
 def augmented_noise_samples(noise_files: NoiseFiles, noise_augmentations: Augmentations) -> int:

@@ -621,18 +567,17 @@ def augmented_noise_samples(noise_files: NoiseFiles, noise_augmentations: Augmen
 def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentation) -> int:
     from .augmentation import estimate_augmented_length_from_length
 
-    return estimate_augmented_length_from_length(length=noise_file.samples,
-                                                 tempo=noise_augmentation.tempo)
+    return estimate_augmented_length_from_length(length=noise_file.samples, tempo=noise_augmentation.tempo)
 
 
-def get_textgrid_tier_from_target_file(target_file: str, tier: str) ->
+def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> SpeechMetadata | None:
     from pathlib import Path
 
     from praatio import textgrid
 
     from .tokenized_shell_vars import tokenized_expand
 
-    textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(
+    textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
     if not textgrid_file.exists():
         return None
 