sonusai 0.19.10__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/data/genmixdb.yml +4 -2
- sonusai/doc/doc.py +14 -0
- sonusai/ir_metric.py +555 -0
- sonusai/metrics_summary.py +5 -3
- sonusai/mixture/__init__.py +4 -1
- sonusai/mixture/audio.py +103 -12
- sonusai/mixture/augmentation.py +199 -84
- sonusai/mixture/config.py +9 -4
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +19 -10
- sonusai/mixture/generation.py +11 -12
- sonusai/mixture/helpers.py +20 -23
- sonusai/mixture/ir_delay.py +63 -0
- sonusai/mixture/mixdb.py +103 -19
- sonusai/mixture/targets.py +3 -6
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/temp_seed.py +13 -0
- {sonusai-0.19.10.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
- {sonusai-0.19.10.dist-info → sonusai-0.20.2.dist-info}/RECORD +21 -23
- {sonusai-0.19.10.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
- sonusai/mixture/soundfile_audio.py +0 -130
- sonusai/mixture/sox_audio.py +0 -476
- sonusai/mixture/sox_augmentation.py +0 -136
- sonusai/mixture/torchaudio_audio.py +0 -106
- sonusai/mixture/torchaudio_augmentation.py +0 -109
- {sonusai-0.19.10.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/audio.py
CHANGED
@@ -58,9 +58,62 @@ def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
|
|
58
58
|
|
59
59
|
@lru_cache
|
60
60
|
def _get_sample_rate(name: str | Path) -> int:
|
61
|
-
from
|
61
|
+
"""Get sample rate from audio file using soundfile
|
62
62
|
|
63
|
-
|
63
|
+
:param name: File name
|
64
|
+
:return: Sample rate
|
65
|
+
"""
|
66
|
+
import soundfile
|
67
|
+
from pydub import AudioSegment
|
68
|
+
|
69
|
+
from .tokenized_shell_vars import tokenized_expand
|
70
|
+
|
71
|
+
expanded_name, _ = tokenized_expand(name)
|
72
|
+
|
73
|
+
try:
|
74
|
+
if expanded_name.endswith(".mp3"):
|
75
|
+
return AudioSegment.from_mp3(expanded_name).frame_rate
|
76
|
+
|
77
|
+
if expanded_name.endswith(".m4a"):
|
78
|
+
return AudioSegment.from_file(expanded_name).frame_rate
|
79
|
+
|
80
|
+
return soundfile.info(expanded_name).samplerate
|
81
|
+
except Exception as e:
|
82
|
+
if name != expanded_name:
|
83
|
+
raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
|
84
|
+
else:
|
85
|
+
raise OSError(f"Error reading {name}: {e}") from e
|
86
|
+
|
87
|
+
|
88
|
+
def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
|
89
|
+
import numpy as np
|
90
|
+
import soundfile
|
91
|
+
from pydub import AudioSegment
|
92
|
+
|
93
|
+
from .tokenized_shell_vars import tokenized_expand
|
94
|
+
|
95
|
+
expanded_name, _ = tokenized_expand(name)
|
96
|
+
|
97
|
+
try:
|
98
|
+
if expanded_name.endswith(".mp3"):
|
99
|
+
sound = AudioSegment.from_mp3(expanded_name)
|
100
|
+
raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
|
101
|
+
raw = raw / 2 ** (sound.sample_width * 8 - 1)
|
102
|
+
sample_rate = sound.frame_rate
|
103
|
+
elif expanded_name.endswith(".m4a"):
|
104
|
+
sound = AudioSegment.from_file(expanded_name)
|
105
|
+
raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
|
106
|
+
raw = raw / 2 ** (sound.sample_width * 8 - 1)
|
107
|
+
sample_rate = sound.frame_rate
|
108
|
+
else:
|
109
|
+
raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
|
110
|
+
except Exception as e:
|
111
|
+
if name != expanded_name:
|
112
|
+
raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
|
113
|
+
else:
|
114
|
+
raise OSError(f"Error reading {name}: {e}") from e
|
115
|
+
|
116
|
+
return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
|
64
117
|
|
65
118
|
|
66
119
|
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
|
@@ -77,28 +130,45 @@ def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
|
|
77
130
|
|
78
131
|
@lru_cache
|
79
132
|
def _read_audio(name: str | Path) -> AudioT:
|
80
|
-
from
|
133
|
+
"""Read audio data from a file using soundfile
|
134
|
+
|
135
|
+
:param name: File name
|
136
|
+
:return: Array of time domain audio data
|
137
|
+
"""
|
138
|
+
import librosa
|
139
|
+
|
140
|
+
from .constants import SAMPLE_RATE
|
141
|
+
|
142
|
+
out, sample_rate = raw_read_audio(name)
|
143
|
+
out = librosa.resample(out, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
|
81
144
|
|
82
|
-
return
|
145
|
+
return out
|
83
146
|
|
84
147
|
|
85
|
-
def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
|
148
|
+
def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
|
86
149
|
"""Read impulse response data
|
87
150
|
|
88
151
|
:param name: File name
|
152
|
+
:param delay: Delay in samples
|
89
153
|
:param use_cache: If true, use LRU caching
|
90
154
|
:return: ImpulseResponseData object
|
91
155
|
"""
|
92
156
|
if use_cache:
|
93
|
-
return _read_ir(name)
|
94
|
-
return _read_ir.__wrapped__(name)
|
157
|
+
return _read_ir(name, delay)
|
158
|
+
return _read_ir.__wrapped__(name, delay)
|
95
159
|
|
96
160
|
|
97
161
|
@lru_cache
|
98
|
-
def _read_ir(name: str | Path) -> ImpulseResponseData:
|
99
|
-
|
162
|
+
def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
|
163
|
+
"""Read impulse response data using soundfile
|
100
164
|
|
101
|
-
|
165
|
+
:param name: File name
|
166
|
+
:param delay: Delay in samples
|
167
|
+
:return: ImpulseResponseData object
|
168
|
+
"""
|
169
|
+
out, sample_rate = raw_read_audio(name)
|
170
|
+
|
171
|
+
return ImpulseResponseData(data=out, sample_rate=sample_rate, delay=delay)
|
102
172
|
|
103
173
|
|
104
174
|
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
|
@@ -120,6 +190,27 @@ def _get_num_samples(name: str | Path) -> int:
|
|
120
190
|
:param name: File name
|
121
191
|
:return: number of samples in resampled audio
|
122
192
|
"""
|
123
|
-
|
193
|
+
import math
|
194
|
+
|
195
|
+
import soundfile
|
196
|
+
from pydub import AudioSegment
|
124
197
|
|
125
|
-
|
198
|
+
from .constants import SAMPLE_RATE
|
199
|
+
from .tokenized_shell_vars import tokenized_expand
|
200
|
+
|
201
|
+
expanded_name, _ = tokenized_expand(name)
|
202
|
+
|
203
|
+
if expanded_name.endswith(".mp3"):
|
204
|
+
sound = AudioSegment.from_mp3(expanded_name)
|
205
|
+
samples = sound.frame_count()
|
206
|
+
sample_rate = sound.frame_rate
|
207
|
+
elif expanded_name.endswith(".m4a"):
|
208
|
+
sound = AudioSegment.from_file(expanded_name)
|
209
|
+
samples = sound.frame_count()
|
210
|
+
sample_rate = sound.frame_rate
|
211
|
+
else:
|
212
|
+
info = soundfile.info(name)
|
213
|
+
samples = info.frames
|
214
|
+
sample_rate = info.samplerate
|
215
|
+
|
216
|
+
return math.ceil(SAMPLE_RATE * samples / sample_rate)
|
sonusai/mixture/augmentation.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AudioT
|
2
2
|
from sonusai.mixture.datatypes import Augmentation
|
3
|
+
from sonusai.mixture.datatypes import AugmentationEffects
|
3
4
|
from sonusai.mixture.datatypes import AugmentationRule
|
4
5
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
5
6
|
from sonusai.mixture.datatypes import OptionalNumberStr
|
7
|
+
from sonusai.mixture.mixdb import MixtureDatabase
|
6
8
|
|
7
9
|
|
8
10
|
def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
|
@@ -41,49 +43,63 @@ def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
|
|
41
43
|
from .constants import VALID_AUGMENTATIONS
|
42
44
|
from .eq_rule_is_valid import eq_rule_is_valid
|
43
45
|
|
44
|
-
|
45
|
-
|
46
|
-
del rule[key]
|
46
|
+
if "pre" not in rule:
|
47
|
+
raise ValueError("Rule must have 'pre' key")
|
47
48
|
|
48
|
-
|
49
|
-
|
49
|
+
if "post" not in rule:
|
50
|
+
rule["post"] = {}
|
51
|
+
|
52
|
+
for key in rule:
|
53
|
+
if key not in ("pre", "post", "mixup"):
|
54
|
+
raise ValueError(f"Invalid augmentation key: {key}")
|
50
55
|
|
56
|
+
if key in ("pre", "post"):
|
57
|
+
for k, v in list(rule[key].items()):
|
58
|
+
if v is None:
|
59
|
+
del rule[key][k]
|
60
|
+
|
61
|
+
# replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility
|
51
62
|
for key in rule:
|
52
|
-
if
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
raise ValueError(f"Invalid augmentation
|
59
|
-
|
60
|
-
if
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
expanded_rule
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
if isinstance(rule[key], list):
|
76
|
-
for value in rule[key]:
|
77
|
-
if isinstance(value, list):
|
78
|
-
raise TypeError(f"Invalid augmentation value for {key}: {rule[key]}")
|
79
|
-
expanded_rule = deepcopy(rule)
|
80
|
-
expanded_rule[key] = deepcopy(value)
|
81
|
-
_expand_rules(expanded_rules, expanded_rule)
|
82
|
-
return expanded_rules
|
63
|
+
rule[key] = {"eq1" if k == "eq" else k: v for k, v in rule[key].items()}
|
64
|
+
|
65
|
+
for key in ("pre", "post"):
|
66
|
+
for k in rule[key]:
|
67
|
+
if k not in VALID_AUGMENTATIONS:
|
68
|
+
nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
|
69
|
+
raise ValueError(f"Invalid augmentation: {k}.\nValid augmentations are:\n{nice_list}")
|
70
|
+
|
71
|
+
if k in ["eq1", "eq2", "eq3"]:
|
72
|
+
if not eq_rule_is_valid(rule[key][k]):
|
73
|
+
raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
|
74
|
+
|
75
|
+
if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key][k]):
|
76
|
+
# Expand multiple rules
|
77
|
+
for value in rule[key][k]:
|
78
|
+
expanded_rule = deepcopy(rule)
|
79
|
+
if isinstance(value, str) and value == "none":
|
80
|
+
expanded_rule[key][k] = None
|
81
|
+
else:
|
82
|
+
expanded_rule[key][k] = deepcopy(value)
|
83
|
+
_expand_rules(expanded_rules, expanded_rule)
|
84
|
+
return expanded_rules
|
85
|
+
|
83
86
|
else:
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
+
if isinstance(rule[key][k], list):
|
88
|
+
for value in rule[key][k]:
|
89
|
+
if isinstance(value, list):
|
90
|
+
raise TypeError(f"Invalid augmentation value for {k}: {rule[key][k]}")
|
91
|
+
expanded_rule = deepcopy(rule)
|
92
|
+
expanded_rule[key][k] = deepcopy(value)
|
93
|
+
_expand_rules(expanded_rules, expanded_rule)
|
94
|
+
return expanded_rules
|
95
|
+
else:
|
96
|
+
rule[key][k] = convert_string_to_number(rule[key][k])
|
97
|
+
if not (
|
98
|
+
isinstance(rule[key][k], float | int)
|
99
|
+
or rule[key][k].startswith("rand")
|
100
|
+
or rule[key][k] == "none"
|
101
|
+
):
|
102
|
+
raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
|
87
103
|
|
88
104
|
expanded_rules.append(rule)
|
89
105
|
return expanded_rules
|
@@ -116,21 +132,22 @@ def _generate_random_rule(rule: dict, num_ir: int = 0) -> dict:
|
|
116
132
|
from random import randint
|
117
133
|
|
118
134
|
out_rule = deepcopy(rule)
|
119
|
-
for key in
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
135
|
+
for key in ("pre", "post"):
|
136
|
+
for k in out_rule[key]:
|
137
|
+
if k == "ir" and out_rule[key][k] == "rand":
|
138
|
+
# IR is special case
|
139
|
+
if num_ir == 0:
|
140
|
+
out_rule[key][k] = None
|
141
|
+
else:
|
142
|
+
out_rule[key][k] = randint(0, num_ir - 1) # noqa: S311
|
124
143
|
else:
|
125
|
-
out_rule[key] =
|
126
|
-
else:
|
127
|
-
out_rule[key] = evaluate_random_rule(str(out_rule[key]))
|
144
|
+
out_rule[key][k] = evaluate_random_rule(str(out_rule[key][k]))
|
128
145
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
146
|
+
# convert EQ values from strings to numbers
|
147
|
+
if k in ("eq1", "eq2", "eq3"):
|
148
|
+
for n in range(3):
|
149
|
+
if isinstance(out_rule[key][k][n], str):
|
150
|
+
out_rule[key][k][n] = eval(out_rule[key][k][n]) # noqa: S307
|
134
151
|
|
135
152
|
return out_rule
|
136
153
|
|
@@ -141,7 +158,7 @@ def _rule_has_rand(rule: dict) -> bool:
|
|
141
158
|
:param rule: Rule
|
142
159
|
:return: True if rule contains 'rand'
|
143
160
|
"""
|
144
|
-
return any("rand" in str(rule[key]) for key in rule)
|
161
|
+
return any("rand" in str(rule[key][k]) for key in rule for k in rule[key])
|
145
162
|
|
146
163
|
|
147
164
|
def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr = None, frame_length: int = 1) -> int:
|
@@ -259,67 +276,165 @@ def _parse_ir(rule: dict, num_ir: int) -> dict:
|
|
259
276
|
raise ValueError(f"Invalid ir entry of {rule_in}")
|
260
277
|
return rule_out
|
261
278
|
|
262
|
-
|
263
|
-
|
279
|
+
def _process(rule_in: dict) -> dict:
|
280
|
+
if "ir" not in rule_in:
|
281
|
+
return rule_in
|
264
282
|
|
265
|
-
|
283
|
+
ir = rule_in["ir"]
|
266
284
|
|
267
|
-
|
268
|
-
|
285
|
+
if ir is None:
|
286
|
+
return rule_in
|
269
287
|
|
270
|
-
|
271
|
-
|
272
|
-
|
288
|
+
if isinstance(ir, str):
|
289
|
+
rule_in["ir"] = _resolve_str(ir)
|
290
|
+
return rule_in
|
273
291
|
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
292
|
+
if isinstance(ir, list):
|
293
|
+
rule_in["ir"] = []
|
294
|
+
for item in ir:
|
295
|
+
result = _resolve_str(item)
|
296
|
+
if isinstance(result, str):
|
297
|
+
rule_in["ir"].append(_resolve_str(item))
|
298
|
+
else:
|
299
|
+
rule_in["ir"] += _resolve_str(item)
|
282
300
|
|
283
|
-
|
301
|
+
return rule_in
|
302
|
+
|
303
|
+
if isinstance(ir, int):
|
304
|
+
if ir not in range(num_ir):
|
305
|
+
raise ValueError(f"Invalid ir of {ir}")
|
306
|
+
return rule_in
|
284
307
|
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
308
|
+
raise ValueError(f"Invalid ir of {ir}")
|
309
|
+
|
310
|
+
for key in rule:
|
311
|
+
if key in ("pre", "post"):
|
312
|
+
rule[key] = _process(rule[key])
|
289
313
|
|
290
|
-
|
314
|
+
return rule
|
291
315
|
|
292
316
|
|
293
|
-
def apply_augmentation(
|
294
|
-
|
317
|
+
def apply_augmentation(
|
318
|
+
mixdb: MixtureDatabase,
|
319
|
+
audio: AudioT,
|
320
|
+
augmentation: AugmentationEffects,
|
321
|
+
frame_length: int = 1,
|
322
|
+
) -> AudioT:
|
323
|
+
"""Apply augmentations to audio data using torchaudio.sox_effects
|
295
324
|
|
325
|
+
:param mixdb: Mixture database
|
296
326
|
:param audio: Audio
|
297
327
|
:param augmentation: Augmentation
|
298
328
|
:param frame_length: Pad resulting audio to be a multiple of this
|
299
329
|
:return: Augmented audio
|
300
330
|
"""
|
301
|
-
|
331
|
+
import numpy as np
|
332
|
+
import torch
|
333
|
+
import torchaudio
|
334
|
+
|
335
|
+
from .audio import read_ir
|
336
|
+
from .constants import SAMPLE_RATE
|
337
|
+
|
338
|
+
effects: list[list[str]] = []
|
339
|
+
|
340
|
+
# TODO: Always normalize and remove normalize from list of available augmentations
|
341
|
+
# Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
|
342
|
+
# TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
|
343
|
+
if augmentation.normalize is not None:
|
344
|
+
effects.append(["norm", str(augmentation.normalize)])
|
345
|
+
|
346
|
+
if augmentation.gain is not None:
|
347
|
+
effects.append(["gain", str(augmentation.gain)])
|
302
348
|
|
303
|
-
|
349
|
+
if augmentation.pitch is not None:
|
350
|
+
effects.append(["pitch", str(augmentation.pitch)])
|
351
|
+
effects.append(["rate", str(SAMPLE_RATE)])
|
352
|
+
|
353
|
+
if augmentation.tempo is not None:
|
354
|
+
effects.append(["tempo", "-s", str(augmentation.tempo)])
|
355
|
+
|
356
|
+
if augmentation.eq1 is not None:
|
357
|
+
effects.append(["equalizer", *[str(item) for item in augmentation.eq1]])
|
358
|
+
|
359
|
+
if augmentation.eq2 is not None:
|
360
|
+
effects.append(["equalizer", *[str(item) for item in augmentation.eq2]])
|
361
|
+
|
362
|
+
if augmentation.eq3 is not None:
|
363
|
+
effects.append(["equalizer", *[str(item) for item in augmentation.eq3]])
|
364
|
+
|
365
|
+
if augmentation.lpf is not None:
|
366
|
+
effects.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])
|
367
|
+
|
368
|
+
if effects:
|
369
|
+
if audio.ndim == 1:
|
370
|
+
audio = np.reshape(audio, (1, audio.shape[0]))
|
371
|
+
out = torch.tensor(audio)
|
372
|
+
|
373
|
+
try:
|
374
|
+
out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
|
375
|
+
except Exception as e:
|
376
|
+
raise RuntimeError(f"Error applying {augmentation}: {e}") from e
|
377
|
+
|
378
|
+
audio_out = np.squeeze(np.array(out))
|
379
|
+
else:
|
380
|
+
audio_out = audio
|
381
|
+
|
382
|
+
if augmentation.ir is not None:
|
383
|
+
audio_out = apply_impulse_response(
|
384
|
+
audio=audio_out,
|
385
|
+
ir=read_ir(
|
386
|
+
name=mixdb.impulse_response_file(augmentation.ir), # pyright: ignore [reportArgumentType]
|
387
|
+
delay=mixdb.impulse_response_delay(augmentation.ir), # pyright: ignore [reportArgumentType]
|
388
|
+
use_cache=mixdb.use_cache,
|
389
|
+
),
|
390
|
+
)
|
391
|
+
|
392
|
+
# make sure length is multiple of frame_length
|
393
|
+
return pad_audio_to_frame(audio=audio_out, frame_length=frame_length)
|
304
394
|
|
305
395
|
|
306
396
|
def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
|
307
|
-
"""Apply impulse response to audio data
|
397
|
+
"""Apply impulse response to audio data using scipy
|
308
398
|
|
309
399
|
:param audio: Audio
|
310
400
|
:param ir: Impulse response data
|
311
401
|
:return: Augmented audio
|
312
402
|
"""
|
313
|
-
|
403
|
+
import numpy as np
|
404
|
+
from librosa import resample
|
405
|
+
from scipy.signal import fftconvolve
|
406
|
+
|
407
|
+
from .constants import SAMPLE_RATE
|
408
|
+
|
409
|
+
# Early exit if no ir or if all audio is zero
|
410
|
+
if ir is None or not audio.any():
|
411
|
+
return audio
|
412
|
+
|
413
|
+
# Convert audio to IR sample rate
|
414
|
+
audio_in = resample(audio, orig_sr=SAMPLE_RATE, target_sr=ir.sample_rate, res_type="soxr_hq")
|
415
|
+
max_in = np.max(np.abs(audio_in))
|
416
|
+
|
417
|
+
# Apply IR
|
418
|
+
audio_out = fftconvolve(audio_in, ir.data, mode="full")
|
314
419
|
|
315
|
-
|
420
|
+
# Delay compensation
|
421
|
+
audio_out = audio_out[ir.delay :]
|
422
|
+
|
423
|
+
# Convert back to global sample rate
|
424
|
+
audio_out = resample(audio_out, orig_sr=ir.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
|
425
|
+
|
426
|
+
# Trim to length
|
427
|
+
audio_out = audio_out[: len(audio)]
|
428
|
+
max_out = np.max(np.abs(audio_out))
|
429
|
+
|
430
|
+
compensation_gain = max_in / max_out
|
431
|
+
|
432
|
+
return audio_out * compensation_gain
|
316
433
|
|
317
434
|
|
318
435
|
def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
|
319
436
|
from sonusai.utils import dataclass_from_dict
|
320
437
|
|
321
|
-
from .datatypes import Augmentation
|
322
|
-
|
323
438
|
processed_rule = rule.to_dict()
|
324
439
|
del processed_rule["mixup"]
|
325
440
|
processed_rule = _generate_none_rule(processed_rule)
|
sonusai/mixture/config.py
CHANGED
@@ -529,7 +529,7 @@ def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
|
|
529
529
|
return list(
|
530
530
|
chain.from_iterable(
|
531
531
|
[
|
532
|
-
append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
|
532
|
+
append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", []), 0))
|
533
533
|
for entry in config["impulse_responses"]
|
534
534
|
]
|
535
535
|
)
|
@@ -552,6 +552,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
552
552
|
from os.path import splitext
|
553
553
|
|
554
554
|
from .audio import validate_input_file
|
555
|
+
from .ir_delay import get_impulse_response_delay
|
555
556
|
from .tokenized_shell_vars import tokenized_expand
|
556
557
|
from .tokenized_shell_vars import tokenized_replace
|
557
558
|
|
@@ -572,7 +573,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
572
573
|
for file in listdir(name):
|
573
574
|
if not isabs(file):
|
574
575
|
file = join(dir_name, file)
|
575
|
-
child = ImpulseResponseFile(file, entry.tags)
|
576
|
+
child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
|
576
577
|
impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
|
577
578
|
else:
|
578
579
|
try:
|
@@ -587,7 +588,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
587
588
|
tokens.update(new_tokens)
|
588
589
|
if not isabs(file):
|
589
590
|
file = join(dir_name, file)
|
590
|
-
child = ImpulseResponseFile(file, entry.tags)
|
591
|
+
child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
|
591
592
|
impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
|
592
593
|
elif ext == ".yml":
|
593
594
|
try:
|
@@ -602,7 +603,11 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
602
603
|
raise OSError(f"Error processing {name}: {e}") from e
|
603
604
|
else:
|
604
605
|
validate_input_file(name)
|
605
|
-
impulse_response_files.append(
|
606
|
+
impulse_response_files.append(
|
607
|
+
ImpulseResponseFile(
|
608
|
+
tokenized_replace(name, tokens), entry.tags, get_impulse_response_delay(name)
|
609
|
+
)
|
610
|
+
)
|
606
611
|
except Exception as e:
|
607
612
|
raise OSError(f"Error processing {name}: {e}") from e
|
608
613
|
|
sonusai/mixture/constants.py
CHANGED
sonusai/mixture/datatypes.py
CHANGED
@@ -75,7 +75,7 @@ EQ: TypeAlias = tuple[float | int, float | int, float | int]
|
|
75
75
|
|
76
76
|
|
77
77
|
@dataclass
|
78
|
-
class
|
78
|
+
class AugmentationRuleEffects(DataClassSonusAIMixin):
|
79
79
|
normalize: OptionalNumberStr = None
|
80
80
|
pitch: OptionalNumberStr = None
|
81
81
|
tempo: OptionalNumberStr = None
|
@@ -85,11 +85,17 @@ class AugmentationRule(DataClassSonusAIMixin):
|
|
85
85
|
eq3: OptionalListNumberStr = None
|
86
86
|
lpf: OptionalNumberStr = None
|
87
87
|
ir: OptionalNumberStr = None
|
88
|
+
|
89
|
+
|
90
|
+
@dataclass
|
91
|
+
class AugmentationRule(DataClassSonusAIMixin):
|
92
|
+
pre: AugmentationRuleEffects
|
93
|
+
post: AugmentationRuleEffects | None = None
|
88
94
|
mixup: int = 1
|
89
95
|
|
90
96
|
|
91
97
|
@dataclass
|
92
|
-
class
|
98
|
+
class AugmentationEffects(DataClassSonusAIMixin):
|
93
99
|
normalize: float | None = None
|
94
100
|
pitch: float | None = None
|
95
101
|
tempo: float | None = None
|
@@ -101,6 +107,12 @@ class Augmentation(DataClassSonusAIMixin):
|
|
101
107
|
ir: int | None = None
|
102
108
|
|
103
109
|
|
110
|
+
@dataclass
|
111
|
+
class Augmentation(DataClassSonusAIMixin):
|
112
|
+
pre: AugmentationEffects
|
113
|
+
post: AugmentationEffects
|
114
|
+
|
115
|
+
|
104
116
|
@dataclass(frozen=True)
|
105
117
|
class UniversalSNRGenerator:
|
106
118
|
is_random: bool
|
@@ -191,19 +203,16 @@ class GenFTData:
|
|
191
203
|
|
192
204
|
@dataclass
|
193
205
|
class ImpulseResponseData:
|
194
|
-
name: str
|
195
|
-
sample_rate: int
|
196
206
|
data: AudioT
|
197
|
-
|
198
|
-
|
199
|
-
def length(self) -> int:
|
200
|
-
return len(self.data)
|
207
|
+
sample_rate: int
|
208
|
+
delay: int
|
201
209
|
|
202
210
|
|
203
211
|
@dataclass
|
204
212
|
class ImpulseResponseFile:
|
205
213
|
file: str
|
206
214
|
tags: list[str]
|
215
|
+
delay: int
|
207
216
|
|
208
217
|
|
209
218
|
@dataclass(frozen=True)
|
@@ -230,9 +239,9 @@ class Target(DataClassSonusAIMixin):
|
|
230
239
|
def gain(self) -> float:
|
231
240
|
# gain is used to back out the gain augmentation in order to return the target audio
|
232
241
|
# to its normalized level when calculating truth (if needed).
|
233
|
-
if self.augmentation.gain is None:
|
242
|
+
if self.augmentation.pre.gain is None:
|
234
243
|
return 1.0
|
235
|
-
return round(10 ** (self.augmentation.gain / 20), ndigits=5)
|
244
|
+
return round(10 ** (self.augmentation.pre.gain / 20), ndigits=5)
|
236
245
|
|
237
246
|
|
238
247
|
Targets: TypeAlias = list[Target]
|