sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/augmentation.py
DELETED
@@ -1,444 +0,0 @@
|
|
1
|
-
from sonusai.mixture.datatypes import AudioT
|
2
|
-
from sonusai.mixture.datatypes import Augmentation
|
3
|
-
from sonusai.mixture.datatypes import AugmentationEffects
|
4
|
-
from sonusai.mixture.datatypes import AugmentationRule
|
5
|
-
from sonusai.mixture.datatypes import ImpulseResponseData
|
6
|
-
from sonusai.mixture.datatypes import OptionalNumberStr
|
7
|
-
from sonusai.mixture.mixdb import MixtureDatabase
|
8
|
-
|
9
|
-
|
10
|
-
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
    """Build the expanded list of augmentation rules from configuration input

    A single rule dictionary is accepted as shorthand for a one-element list.
    Each rule has its 'ir' entries resolved and is then expanded into one or
    more single-valued rules before being converted to AugmentationRule.

    :param rules: Augmentation config rule, or list of rules
    :param num_ir: Number of impulse responses in config
    :return: List of augmentation rules
    """
    from sonusai.utils import dataclass_from_dict

    from .datatypes import AugmentationRule

    rule_list = rules if isinstance(rules, list) else [rules]

    expanded: list[dict] = []
    for entry in rule_list:
        # _expand_rules appends to, and returns, the working list it is given
        expanded = _expand_rules(expanded_rules=expanded, rule=_parse_ir(entry, num_ir))

    return [dataclass_from_dict(AugmentationRule, item) for item in expanded]  # pyright: ignore [reportReturnType]
|
30
|
-
|
31
|
-
|
32
|
-
def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
    """Expand a rule with multi-valued effects into single-valued rules

    The rule is validated, entries whose value is None are dropped from the
    'pre' and 'post' sections, and the legacy 'eq' effect name is mapped to
    'eq1'. The first multi-valued effect found is split into one copy of the
    rule per value; each copy is expanded recursively. A rule with no
    multi-valued effects is appended to the working list as is.

    Note that the given rule is modified in place.

    :param expanded_rules: Working list of expanded rules (appended to in place)
    :param rule: Rule to process
    :return: List of expanded rules
    :raises ValueError: If 'pre' is missing, a key or effect name is unknown,
        or an effect value is not valid
    :raises TypeError: If a non-EQ effect contains a nested list
    """
    from copy import deepcopy

    from sonusai.utils import convert_string_to_number

    from .constants import VALID_AUGMENTATIONS
    from .eq_rule_is_valid import eq_rule_is_valid

    if "pre" not in rule:
        raise ValueError("Rule must have 'pre' key")

    if "post" not in rule:
        rule["post"] = {}

    for key in rule:
        if key not in ("pre", "post", "mixup"):
            raise ValueError(f"Invalid augmentation key: {key}")

        if key in ("pre", "post"):
            for k, v in list(rule[key].items()):
                if v is None:
                    del rule[key][k]

    # Replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility.
    # Only 'pre' and 'post' hold effect dictionaries; 'mixup' is a plain value and
    # has no items() to iterate over, so it must not be visited here.
    for key in ("pre", "post"):
        rule[key] = {"eq1" if k == "eq" else k: v for k, v in rule[key].items()}

    for key in ("pre", "post"):
        for k in rule[key]:
            if k not in VALID_AUGMENTATIONS:
                nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
                raise ValueError(f"Invalid augmentation: {k}.\nValid augmentations are:\n{nice_list}")

            if k in ["eq1", "eq2", "eq3"]:
                if not eq_rule_is_valid(rule[key][k]):
                    raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")

                if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key][k]):
                    # Expand multiple rules
                    for value in rule[key][k]:
                        expanded_rule = deepcopy(rule)
                        if isinstance(value, str) and value == "none":
                            # None entries are removed when the copy is expanded
                            expanded_rule[key][k] = None
                        else:
                            expanded_rule[key][k] = deepcopy(value)
                        _expand_rules(expanded_rules, expanded_rule)
                    return expanded_rules

            else:
                if isinstance(rule[key][k], list):
                    for value in rule[key][k]:
                        if isinstance(value, list):
                            raise TypeError(f"Invalid augmentation value for {k}: {rule[key][k]}")
                        expanded_rule = deepcopy(rule)
                        expanded_rule[key][k] = deepcopy(value)
                        _expand_rules(expanded_rules, expanded_rule)
                    return expanded_rules
                else:
                    rule[key][k] = convert_string_to_number(rule[key][k])
                    if not (
                        isinstance(rule[key][k], float | int)
                        or rule[key][k].startswith("rand")
                        or rule[key][k] == "none"
                    ):
                        raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")

    expanded_rules.append(rule)
    return expanded_rules
|
106
|
-
|
107
|
-
|
108
|
-
def _generate_none_rule(rule: dict) -> dict:
    """Generate a new rule from a rule that contains 'none' directives

    Every value equal to the string 'none' is replaced with None. Nested
    dictionaries (such as the 'pre' and 'post' sections of a rule) are
    searched as well; all other values are copied unchanged. The input rule
    is not modified.

    :param rule: Rule
    :return: New rule
    """
    from copy import deepcopy

    def _resolve(value):
        if isinstance(value, dict):
            return {k: _resolve(v) for k, v in value.items()}
        if isinstance(value, str) and value == "none":
            return None
        return value

    # Work on a deep copy so the returned rule shares no state with the input
    return {key: _resolve(value) for key, value in deepcopy(rule).items()}
|
122
|
-
|
123
|
-
|
124
|
-
def _generate_random_rule(rule: dict, num_ir: int = 0) -> dict:
    """Generate a new rule from a rule that contains 'rand' directives

    Every effect in the 'pre' and 'post' sections is resolved. An 'ir' value
    of 'rand' picks a random impulse response index (or None when there are
    no impulse responses); every other value is passed through
    evaluate_random_rule. The input rule is not modified.

    :param rule: Rule
    :param num_ir: Number of impulse responses in config
    :return: Randomized rule
    """
    from copy import deepcopy
    from random import randint

    out_rule = deepcopy(rule)
    for key in ("pre", "post"):
        effects = out_rule[key]
        for k in effects:
            if k == "ir" and effects[k] == "rand":
                # IR is special case
                if num_ir == 0:
                    effects[k] = None
                else:
                    effects[k] = randint(0, num_ir - 1)  # noqa: S311
            else:
                effects[k] = evaluate_random_rule(str(effects[k]))

            # Convert EQ values from strings to numbers. An EQ effect that resolved
            # to something other than a list (e.g. None) has no elements to convert.
            if k in ("eq1", "eq2", "eq3") and isinstance(effects[k], list):
                for n, element in enumerate(effects[k]):
                    if isinstance(element, str):
                        # NOTE(review): eval of config-derived text; only safe for trusted config
                        effects[k][n] = eval(element)  # noqa: S307

    return out_rule
|
153
|
-
|
154
|
-
|
155
|
-
def _rule_has_rand(rule: dict) -> bool:
    """Report whether any value in any section of the rule mentions 'rand'

    Each value is converted to its string form before searching, so 'rand'
    nested inside a list value is detected as well.

    :param rule: Rule
    :return: True if rule contains 'rand'
    """
    for section in rule:
        effects = rule[section]
        for name in effects:
            if "rand" in str(effects[name]):
                return True
    return False
|
162
|
-
|
163
|
-
|
164
|
-
def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr = None, frame_length: int = 1) -> int:
    """Predict how many samples audio will have once it has been augmented

    When a tempo is given, the length is divided by it and rounded. The
    result is then rounded up to the next multiple of frame_length.

    :param length: Number of samples in audio
    :param tempo: Tempo rule
    :param frame_length: Pad resulting audio to be a multiple of this
    :return: Estimated length of augmented audio
    """
    import numpy as np

    estimated = length if tempo is None else int(np.round(length / float(tempo)))

    # Round up to a whole number of frames
    remainder = int(estimated % frame_length)
    if remainder:
        estimated += frame_length - remainder

    return estimated
|
180
|
-
|
181
|
-
|
182
|
-
def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
    """Collect the distinct mixup values used by the given augmentations

    :param augmentations: List of augmentations
    :return: Sorted list of the unique mixup values
    """
    unique_mixups = set()
    for rule in augmentations:
        unique_mixups.add(rule.mixup)
    return sorted(unique_mixups)
|
189
|
-
|
190
|
-
|
191
|
-
def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
    """Find the positions of all augmentations that use a given mixup value

    :param augmentations: List of augmentations
    :param mixup: Mixup value of interest
    :return: List of augmentation indices, in ascending order
    """
    return [position for position, rule in enumerate(augmentations) if mixup == rule.mixup]
|
204
|
-
|
205
|
-
|
206
|
-
def pad_audio_to_frame(audio: AudioT, frame_length: int = 1) -> AudioT:
    """Extend audio so that its length is a whole number of frames

    :param audio: Audio
    :param frame_length: Pad resulting audio to be a multiple of this
    :return: Padded audio
    """
    target_length = _get_padded_length(len(audio), frame_length)
    return pad_audio_to_length(audio, target_length)
|
214
|
-
|
215
|
-
|
216
|
-
def _get_padded_length(length: int, frame_length: int) -> int:
    """Round a length up to the next multiple of the frame length

    A length that is already a multiple of frame_length is returned unchanged.

    :param length: Length of audio
    :param frame_length: Desired length will be a multiple of this
    :return: Padded length
    """
    remainder = int(length % frame_length)
    if not remainder:
        return length
    return length + (frame_length - remainder)
|
226
|
-
|
227
|
-
|
228
|
-
def pad_audio_to_length(audio: AudioT, length: int) -> AudioT:
    """Append padding to the end of audio until it has the given length

    Uses numpy.pad with its default padding mode; nothing is added to the
    start of the audio.

    :param audio: Audio
    :param length: Length of output
    :return: Padded audio
    """
    import numpy as np

    samples_to_add = length - len(audio)
    return np.pad(array=audio, pad_width=(0, samples_to_add))
|
238
|
-
|
239
|
-
|
240
|
-
def apply_gain(audio: AudioT, gain: float) -> AudioT:
    """Scale audio by a multiplicative gain factor

    :param audio: Audio
    :param gain: Amount of gain (the audio is multiplied by this value)
    :return: Adjusted audio
    """
    scaled = audio * gain
    return scaled
|
248
|
-
|
249
|
-
|
250
|
-
def evaluate_random_rule(rule: str) -> str | float:
    """Resolve the 'rand' directives in a rule and evaluate the result

    Every match of RAND_PATTERN is replaced with a uniformly distributed
    random number formatted to two decimal places; the resulting text is then
    evaluated as a Python expression.

    :param rule: Rule
    :return: Resolved value
    """
    import re
    from random import uniform

    from .constants import RAND_PATTERN

    def _draw(match):
        # Groups 1 and 4 of RAND_PATTERN are used as the bounds of the draw
        low = float(match.group(1))
        high = float(match.group(4))
        return f"{uniform(low, high):.2f}"  # noqa: S311

    resolved = re.sub(RAND_PATTERN, _draw, rule)
    return eval(resolved)  # noqa: S307
|
265
|
-
|
266
|
-
|
267
|
-
def _parse_ir(rule: dict, num_ir: int) -> dict:
    """Validate and normalize the 'ir' entries in a rule's 'pre' and 'post' sections

    A string entry is either kept ('rand' or 'none') or converted to a list of
    impulse response indices by generic_ids_to_list. A list entry is flattened
    into a single list made up of the resolved form of each item. An integer
    entry is only checked for range. The rule is modified in place.

    :param rule: Rule
    :param num_ir: Number of impulse responses in config
    :return: The given rule with resolved 'ir' entries
    :raises ValueError: If an 'ir' entry has an unsupported type or refers to
        an index outside of range(num_ir)
    """
    from .helpers import generic_ids_to_list

    def _resolve_str(rule_in: str) -> str | list[int]:
        if rule_in in ["rand", "none"]:
            return rule_in

        rule_out = generic_ids_to_list(num_ir, rule_in)
        if not all(ro in range(num_ir) for ro in rule_out):
            raise ValueError(f"Invalid ir entry of {rule_in}")
        return rule_out

    def _process(rule_in: dict) -> dict:
        if "ir" not in rule_in:
            return rule_in

        ir = rule_in["ir"]

        if ir is None:
            return rule_in

        if isinstance(ir, str):
            rule_in["ir"] = _resolve_str(ir)
            return rule_in

        if isinstance(ir, list):
            rule_in["ir"] = []
            for item in ir:
                # Resolve each item once and reuse the result
                result = _resolve_str(item)
                if isinstance(result, str):
                    rule_in["ir"].append(result)
                else:
                    rule_in["ir"] += result

            return rule_in

        if isinstance(ir, int):
            if ir not in range(num_ir):
                raise ValueError(f"Invalid ir of {ir}")
            return rule_in

        raise ValueError(f"Invalid ir of {ir}")

    for key in rule:
        if key in ("pre", "post"):
            rule[key] = _process(rule[key])

    return rule
|
315
|
-
|
316
|
-
|
317
|
-
def apply_augmentation(
    mixdb: MixtureDatabase,
    audio: AudioT,
    augmentation: AugmentationEffects,
    frame_length: int = 1,
) -> AudioT:
    """Run the requested augmentation effects over audio data

    The sox effect chain is processed by torchaudio.sox_effects. When an
    impulse response is requested it is applied afterwards, and the result is
    finally padded to a multiple of frame_length.

    :param mixdb: Mixture database
    :param audio: Audio
    :param augmentation: Augmentation
    :param frame_length: Pad resulting audio to be a multiple of this
    :return: Augmented audio
    :raises RuntimeError: If torchaudio fails to apply the effect chain
    """
    import numpy as np
    import torch
    import torchaudio

    from .audio import read_ir
    from .constants import SAMPLE_RATE

    # TODO: Always normalize and remove normalize from list of available augmentations
    # Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
    # TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
    chain: list[list[str]] = []

    if augmentation.normalize is not None:
        chain.append(["norm", str(augmentation.normalize)])

    if augmentation.gain is not None:
        chain.append(["gain", str(augmentation.gain)])

    if augmentation.pitch is not None:
        # A pitch change is always followed by a 'rate' effect at SAMPLE_RATE
        chain += [["pitch", str(augmentation.pitch)], ["rate", str(SAMPLE_RATE)]]

    if augmentation.tempo is not None:
        chain.append(["tempo", "-s", str(augmentation.tempo)])

    # Up to three equalizer bands, applied in order
    for band in (augmentation.eq1, augmentation.eq2, augmentation.eq3):
        if band is not None:
            chain.append(["equalizer", *[str(item) for item in band]])

    if augmentation.lpf is not None:
        chain.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])

    if not chain:
        audio_out = audio
    else:
        # One-dimensional audio is given a leading axis of size one before processing
        if audio.ndim == 1:
            audio = np.reshape(audio, (1, audio.shape[0]))
        tensor = torch.tensor(audio)

        try:
            tensor, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate=SAMPLE_RATE, effects=chain)
        except Exception as e:
            raise RuntimeError(f"Error applying {augmentation}: {e}") from e

        audio_out = np.squeeze(np.array(tensor))

    if augmentation.ir is not None:
        impulse_response = read_ir(
            name=mixdb.impulse_response_file(augmentation.ir),  # pyright: ignore [reportArgumentType]
            delay=mixdb.impulse_response_delay(augmentation.ir),  # pyright: ignore [reportArgumentType]
            use_cache=mixdb.use_cache,
        )
        audio_out = apply_impulse_response(audio=audio_out, ir=impulse_response)

    # make sure length is multiple of frame_length
    return pad_audio_to_frame(audio=audio_out, frame_length=frame_length)
|
394
|
-
|
395
|
-
|
396
|
-
def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
    """Apply impulse response to audio data using scipy

    The audio is resampled to the impulse response sample rate, convolved
    with the impulse response, delay compensated, resampled back to
    SAMPLE_RATE and trimmed to at most the input length. The result is scaled
    so that its peak magnitude matches the peak of the (resampled) input.

    :param audio: Audio
    :param ir: Impulse response data
    :return: Augmented audio
    """
    import numpy as np
    from librosa import resample
    from scipy.signal import fftconvolve

    from .constants import SAMPLE_RATE

    # Early exit if no ir or if all audio is zero
    if ir is None or not audio.any():
        return audio

    # Convert audio to IR sample rate
    audio_in = resample(audio, orig_sr=SAMPLE_RATE, target_sr=ir.sample_rate, res_type="soxr_hq")
    max_in = np.max(np.abs(audio_in))

    # Apply IR
    audio_out = fftconvolve(audio_in, ir.data, mode="full")

    # Delay compensation
    audio_out = audio_out[ir.delay :]

    # Convert back to global sample rate
    audio_out = resample(audio_out, orig_sr=ir.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")

    # Trim to length
    audio_out = audio_out[: len(audio)]
    max_out = np.max(np.abs(audio_out))

    # A silent result cannot be level matched; dividing by a zero peak would
    # turn every sample into NaN, so return the silence unscaled instead.
    if max_out == 0:
        return audio_out

    compensation_gain = max_in / max_out

    return audio_out * compensation_gain
|
433
|
-
|
434
|
-
|
435
|
-
def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
    """Create a concrete augmentation from an augmentation rule

    The rule's 'mixup' entry is dropped, 'none' directives are resolved and,
    when the rule contains any 'rand' directive, random values are drawn.

    :param rule: Augmentation rule
    :param num_ir: Number of impulse responses in config
    :return: Augmentation
    """
    from sonusai.utils import dataclass_from_dict

    fields = rule.to_dict()
    fields.pop("mixup")
    fields = _generate_none_rule(fields)
    if _rule_has_rand(fields):
        fields = _generate_random_rule(fields, num_ir)

    return dataclass_from_dict(Augmentation, fields)  # pyright: ignore [reportReturnType]
|
sonusai/mixture/class_count.py
DELETED
@@ -1,15 +0,0 @@
|
|
1
|
-
from sonusai.mixture.datatypes import ClassCount
|
2
|
-
from sonusai.mixture.datatypes import GeneralizedIDs
|
3
|
-
from sonusai.mixture.mixdb import MixtureDatabase
|
4
|
-
|
5
|
-
|
6
|
-
def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
    """Add up the per-class counts of the selected mixtures

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to include
    :return: Total count for each class, indexed by class
    """
    totals = [0] * mixdb.num_classes
    for m_id in mixdb.mixids_to_list(mixids):
        counts = mixdb.mixture_class_count(m_id)
        totals = [running + counts[cl] for cl, running in enumerate(totals)]

    return totals
|
@@ -1,45 +0,0 @@
|
|
1
|
-
from typing import Any
|
2
|
-
|
3
|
-
|
4
|
-
def eq_rule_is_valid(rule: Any) -> bool:
|
5
|
-
"""Check if EQ rule is valid
|
6
|
-
|
7
|
-
An EQ rule must be a tuple of length 3 or a list of length 3 tuples.
|
8
|
-
"""
|
9
|
-
|
10
|
-
# Must be a list or string equal to 'none'
|
11
|
-
if isinstance(rule, str) and rule == "none":
|
12
|
-
return True
|
13
|
-
|
14
|
-
if not isinstance(rule, list):
|
15
|
-
return False
|
16
|
-
|
17
|
-
if len(rule) != 3:
|
18
|
-
# If the length is not 3, then all elements must also be lists
|
19
|
-
if not all(_check_for_none(el) for el in rule):
|
20
|
-
return False
|
21
|
-
rules = rule
|
22
|
-
else:
|
23
|
-
rules = [rule]
|
24
|
-
|
25
|
-
for r in rules:
|
26
|
-
# Each item must be a number or string
|
27
|
-
if not all(isinstance(el, float | int | str) for el in r):
|
28
|
-
return False
|
29
|
-
|
30
|
-
if isinstance(r, str) and r == "none":
|
31
|
-
continue
|
32
|
-
|
33
|
-
for el in r:
|
34
|
-
# If a string, item must start with 'rand'
|
35
|
-
if isinstance(el, str) and not el.startswith("rand"):
|
36
|
-
return False
|
37
|
-
|
38
|
-
return True
|
39
|
-
|
40
|
-
|
41
|
-
def _check_for_none(rule: Any) -> bool:
|
42
|
-
"""Check if EQ rule is 'none'"""
|
43
|
-
if isinstance(rule, str) and rule == "none":
|
44
|
-
return True
|
45
|
-
return bool(isinstance(rule, list) and len(rule) == 3)
|
@@ -1,107 +0,0 @@
|
|
1
|
-
from sonusai.mixture.datatypes import AugmentationRule
|
2
|
-
from sonusai.mixture.datatypes import AugmentedTarget
|
3
|
-
from sonusai.mixture.datatypes import TargetFile
|
4
|
-
|
5
|
-
|
6
|
-
def balance_targets(
    augmented_targets: list[AugmentedTarget],
    targets: list[TargetFile],
    target_augmentations: list[AugmentationRule],
    class_balancing_augmentation: AugmentationRule,
    num_classes: int,
    num_ir: int,
    mixups: list[int] | None = None,
) -> tuple[list[AugmentedTarget], list[AugmentationRule]]:
    """Add augmented targets so that every class has the same number of them

    For each mixup value other than 1, every class is topped up to the size of
    the largest class (rounded up to a multiple of the mixup) by cycling
    through that class's target files and pairing each with a class balancing
    augmentation it has not used yet. Both input lists are extended in place.

    A class without any augmented targets has no target file to augment and
    is left as is.

    :param augmented_targets: List of augmented targets
    :param targets: List of target files
    :param target_augmentations: List of target augmentation rules
    :param class_balancing_augmentation: Default class balancing augmentation rule
    :param num_classes: Number of classes
    :param num_ir: Number of impulse responses in config
    :param mixups: Mixup values to balance; taken from target_augmentations if not given
    :return: Tuple of (augmented targets, target augmentation rules)
    """
    import math

    from .augmentation import get_mixups
    from .datatypes import AugmentedTarget
    from .targets import get_augmented_target_ids_by_class

    # Rules at or beyond this index are class balancing augmentations
    first_cba_id = len(target_augmentations)

    if mixups is None:
        mixups = get_mixups(target_augmentations)

    for mixup in mixups:
        if mixup == 1:
            continue

        augmented_target_indices_by_class = get_augmented_target_ids_by_class(
            augmented_targets=augmented_targets,
            targets=targets,
            target_augmentations=target_augmentations,
            mixup=mixup,
            num_classes=num_classes,
        )

        largest = max((len(item) for item in augmented_target_indices_by_class), default=0)
        largest = math.ceil(largest / mixup) * mixup
        for at_indices in augmented_target_indices_by_class:
            target_ids = sorted({augmented_targets[at_index].target_id for at_index in at_indices})
            if not target_ids:
                # Nothing to build additional augmented targets from
                continue

            additional_augmentations_needed = largest - len(at_indices)
            for n in range(additional_augmentations_needed):
                # Cycle through the target files of this class
                target_id = target_ids[n % len(target_ids)]
                augmentation_index, target_augmentations = _get_unused_balancing_augmentation(
                    augmented_targets=augmented_targets,
                    targets=targets,
                    target_augmentations=target_augmentations,
                    class_balancing_augmentation=class_balancing_augmentation,
                    target_id=target_id,
                    mixup=mixup,
                    num_ir=num_ir,
                    first_cba_id=first_cba_id,
                )
                augmented_target = AugmentedTarget(target_id=target_id, target_augmentation_id=augmentation_index)
                augmented_targets.append(augmented_target)

    return augmented_targets, target_augmentations
|
62
|
-
|
63
|
-
|
64
|
-
def _get_unused_balancing_augmentation(
    augmented_targets: list[AugmentedTarget],
    targets: list[TargetFile],
    target_augmentations: list[AugmentationRule],
    class_balancing_augmentation: AugmentationRule,
    target_id: int,
    mixup: int,
    num_ir: int,
    first_cba_id: int,
) -> tuple[int, list[AugmentationRule]]:
    """Find, or create, a balancing augmentation not yet used by a target file

    Balancing augmentations are the rules at index first_cba_id and above.
    The first one with the requested mixup that is not already paired with
    the target is returned; when there is none, a new rule is generated from
    the target's class balancing augmentation and appended to the list.

    :param augmented_targets: List of augmented targets
    :param targets: List of target files
    :param target_augmentations: List of target augmentation rules (may be appended to)
    :param class_balancing_augmentation: Default class balancing augmentation rule
    :param target_id: Index of the target file
    :param mixup: Mixup value
    :param num_ir: Number of impulse responses in config
    :param first_cba_id: Index of the first class balancing augmentation
    :return: Tuple of (augmentation index, target augmentation rules)
    """
    from dataclasses import asdict

    from .augmentation import get_augmentation_rules

    balancing_ids = range(max(first_cba_id, 0), len(target_augmentations))
    already_used = {
        at.target_augmentation_id
        for at in augmented_targets
        if at.target_id == target_id and at.target_augmentation_id in balancing_ids
    }

    for candidate in balancing_ids:
        if candidate not in already_used and target_augmentations[candidate].mixup == mixup:
            return candidate, target_augmentations

    # Every suitable balancing augmentation is taken; generate another one
    template = get_class_balancing_augmentation(target=targets[target_id], default_cba=class_balancing_augmentation)
    new_rule = get_augmentation_rules(rules=asdict(template), num_ir=num_ir)[0]
    new_rule.mixup = mixup
    target_augmentations.append(new_rule)
    return len(target_augmentations) - 1, target_augmentations
|
101
|
-
|
102
|
-
|
103
|
-
def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule:
    """Select the class balancing augmentation rule to use for a target

    :param target: Target file
    :param default_cba: Rule to use when the target does not define its own
    :return: The target's own rule if it has one, otherwise the default
    """
    own_rule = target.class_balancing_augmentation
    return default_cba if own_rule is None else own_rule
|