sonusai 0.19.10__py3-none-any.whl → 0.20.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
sonusai/mixture/audio.py CHANGED
@@ -58,9 +58,62 @@ def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
 
 @lru_cache
 def _get_sample_rate(name: str | Path) -> int:
-    from .soundfile_audio import get_sample_rate
+    """Get sample rate from audio file using soundfile
 
-    return get_sample_rate(name)
+    :param name: File name
+    :return: Sample rate
+    """
+    import soundfile
+    from pydub import AudioSegment
+
+    from .tokenized_shell_vars import tokenized_expand
+
+    expanded_name, _ = tokenized_expand(name)
+
+    try:
+        if expanded_name.endswith(".mp3"):
+            return AudioSegment.from_mp3(expanded_name).frame_rate
+
+        if expanded_name.endswith(".m4a"):
+            return AudioSegment.from_file(expanded_name).frame_rate
+
+        return soundfile.info(expanded_name).samplerate
+    except Exception as e:
+        if name != expanded_name:
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
+        else:
+            raise OSError(f"Error reading {name}: {e}") from e
+
+
+def raw_read_audio(name: str | Path) -> tuple[AudioT, int]:
+    import numpy as np
+    import soundfile
+    from pydub import AudioSegment
+
+    from .tokenized_shell_vars import tokenized_expand
+
+    expanded_name, _ = tokenized_expand(name)
+
+    try:
+        if expanded_name.endswith(".mp3"):
+            sound = AudioSegment.from_mp3(expanded_name)
+            raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
+            raw = raw / 2 ** (sound.sample_width * 8 - 1)
+            sample_rate = sound.frame_rate
+        elif expanded_name.endswith(".m4a"):
+            sound = AudioSegment.from_file(expanded_name)
+            raw = np.array(sound.get_array_of_samples()).astype(np.float32).reshape((-1, sound.channels))
+            raw = raw / 2 ** (sound.sample_width * 8 - 1)
+            sample_rate = sound.frame_rate
+        else:
+            raw, sample_rate = soundfile.read(expanded_name, always_2d=True, dtype="float32")
+    except Exception as e:
+        if name != expanded_name:
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
+        else:
+            raise OSError(f"Error reading {name}: {e}") from e
+
+    return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
 
 
 def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
@@ -77,28 +130,45 @@ def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
 
 @lru_cache
 def _read_audio(name: str | Path) -> AudioT:
-    from .soundfile_audio import read_audio
+    """Read audio data from a file using soundfile
+
+    :param name: File name
+    :return: Array of time domain audio data
+    """
+    import librosa
+
+    from .constants import SAMPLE_RATE
+
+    out, sample_rate = raw_read_audio(name)
+    out = librosa.resample(out, orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
 
-    return read_audio(name)
+    return out
 
 
-def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
+def read_ir(name: str | Path, delay: int, use_cache: bool = True) -> ImpulseResponseData:
     """Read impulse response data
 
     :param name: File name
+    :param delay: Delay in samples
     :param use_cache: If true, use LRU caching
     :return: ImpulseResponseData object
     """
     if use_cache:
-        return _read_ir(name)
-    return _read_ir.__wrapped__(name)
+        return _read_ir(name, delay)
+    return _read_ir.__wrapped__(name, delay)
 
 
 @lru_cache
-def _read_ir(name: str | Path) -> ImpulseResponseData:
-    from .soundfile_audio import read_ir
+def _read_ir(name: str | Path, delay: int) -> ImpulseResponseData:
+    """Read impulse response data using soundfile
 
-    return read_ir(name)
+    :param name: File name
+    :param delay: Delay in samples
+    :return: ImpulseResponseData object
+    """
+    out, sample_rate = raw_read_audio(name)
+
+    return ImpulseResponseData(data=out, sample_rate=sample_rate, delay=delay)
 
 
 def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
@@ -120,6 +190,27 @@ def _get_num_samples(name: str | Path) -> int:
     :param name: File name
     :return: number of samples in resampled audio
     """
-    from .soundfile_audio import get_num_samples
+    import math
+
+    import soundfile
+    from pydub import AudioSegment
 
-    return get_num_samples(name)
+    from .constants import SAMPLE_RATE
+    from .tokenized_shell_vars import tokenized_expand
+
+    expanded_name, _ = tokenized_expand(name)
+
+    if expanded_name.endswith(".mp3"):
+        sound = AudioSegment.from_mp3(expanded_name)
+        samples = sound.frame_count()
+        sample_rate = sound.frame_rate
+    elif expanded_name.endswith(".m4a"):
+        sound = AudioSegment.from_file(expanded_name)
+        samples = sound.frame_count()
+        sample_rate = sound.frame_rate
+    else:
+        info = soundfile.info(name)
+        samples = info.frames
+        sample_rate = info.samplerate
+
+    return math.ceil(SAMPLE_RATE * samples / sample_rate)
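
With these changes, audio.py reads files directly through soundfile/pydub and resamples with librosa, and read_ir now takes the impulse response delay. A minimal usage sketch of the reworked readers; the file names are placeholders:

    # Sketch only: file names are placeholders.
    from sonusai.mixture.audio import get_sample_rate, read_audio, read_ir

    rate = get_sample_rate("speech.mp3")  # .mp3/.m4a go through pydub, everything else through soundfile
    audio = read_audio("speech.mp3")      # mono float32, resampled to SAMPLE_RATE with librosa's soxr_hq
    ir = read_ir("impulse.wav", delay=0)  # delay (in samples) is now a required argument
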
@@ -1,8 +1,10 @@
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import Augmentation
+from sonusai.mixture.datatypes import AugmentationEffects
 from sonusai.mixture.datatypes import AugmentationRule
 from sonusai.mixture.datatypes import ImpulseResponseData
 from sonusai.mixture.datatypes import OptionalNumberStr
+from sonusai.mixture.mixdb import MixtureDatabase
 
 
 def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
@@ -41,49 +43,63 @@ def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
     from .constants import VALID_AUGMENTATIONS
     from .eq_rule_is_valid import eq_rule_is_valid
 
-    for key, value in list(rule.items()):
-        if value is None:
-            del rule[key]
+    if "pre" not in rule:
+        raise ValueError("Rule must have 'pre' key")
 
-    # replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility
-    rule = {"eq1" if key == "eq" else key: value for key, value in rule.items()}
+    if "post" not in rule:
+        rule["post"] = {}
+
+    for key in rule:
+        if key not in ("pre", "post", "mixup"):
+            raise ValueError(f"Invalid augmentation key: {key}")
 
+        if key in ("pre", "post"):
+            for k, v in list(rule[key].items()):
+                if v is None:
+                    del rule[key][k]
+
+    # replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility
     for key in rule:
-        if key not in VALID_AUGMENTATIONS:
-            nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
-            raise ValueError(f"Invalid augmentation: {key}.\nValid augmentations are:\n{nice_list}")
-
-        if key in ["eq1", "eq2", "eq3"]:
-            if not eq_rule_is_valid(rule[key]):
-                raise ValueError(f"Invalid augmentation value for {key}: {rule[key]}")
-
-            if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key]):
-                # Expand multiple rules
-                for value in rule[key]:
-                    expanded_rule = deepcopy(rule)
-                    if isinstance(value, str) and value == "none":
-                        expanded_rule[key] = None
-                    else:
-                        expanded_rule[key] = deepcopy(value)
-                    _expand_rules(expanded_rules, expanded_rule)
-                return expanded_rules
-
-        elif key in ["mixup"]:
-            pass
-
-        else:
-            if isinstance(rule[key], list):
-                for value in rule[key]:
-                    if isinstance(value, list):
-                        raise TypeError(f"Invalid augmentation value for {key}: {rule[key]}")
-                    expanded_rule = deepcopy(rule)
-                    expanded_rule[key] = deepcopy(value)
-                    _expand_rules(expanded_rules, expanded_rule)
-                return expanded_rules
+        rule[key] = {"eq1" if k == "eq" else k: v for k, v in rule[key].items()}
+
+    for key in ("pre", "post"):
+        for k in rule[key]:
+            if k not in VALID_AUGMENTATIONS:
+                nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
+                raise ValueError(f"Invalid augmentation: {k}.\nValid augmentations are:\n{nice_list}")
+
+            if k in ["eq1", "eq2", "eq3"]:
+                if not eq_rule_is_valid(rule[key][k]):
+                    raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
+
+                if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key][k]):
+                    # Expand multiple rules
+                    for value in rule[key][k]:
+                        expanded_rule = deepcopy(rule)
+                        if isinstance(value, str) and value == "none":
+                            expanded_rule[key][k] = None
+                        else:
+                            expanded_rule[key][k] = deepcopy(value)
+                        _expand_rules(expanded_rules, expanded_rule)
+                    return expanded_rules
+
             else:
-                rule[key] = convert_string_to_number(rule[key])
-                if not (isinstance(rule[key], float | int) or rule[key].startswith("rand") or rule[key] == "none"):
-                    raise ValueError(f"Invalid augmentation value for {key}: {rule[key]}")
+                if isinstance(rule[key][k], list):
+                    for value in rule[key][k]:
+                        if isinstance(value, list):
+                            raise TypeError(f"Invalid augmentation value for {k}: {rule[key][k]}")
+                        expanded_rule = deepcopy(rule)
+                        expanded_rule[key][k] = deepcopy(value)
+                        _expand_rules(expanded_rules, expanded_rule)
+                    return expanded_rules
+                else:
+                    rule[key][k] = convert_string_to_number(rule[key][k])
+                    if not (
+                        isinstance(rule[key][k], float | int)
+                        or rule[key][k].startswith("rand")
+                        or rule[key][k] == "none"
+                    ):
+                        raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
 
     expanded_rules.append(rule)
     return expanded_rules
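
The validation above implies a nested rule shape: a required "pre" section, an optional "post" section, and an optional "mixup" count. An illustrative rule dict (effect values are arbitrary):

    # Illustrative only; effect values are arbitrary.
    rule = {
        "pre": {
            "tempo": [0.9, 1.0, 1.1],  # list values are expanded into one rule per entry
            "eq": [1000, 1, -6],       # legacy "eq" is renamed to "eq1" for backward compatibility
        },
        "post": {
            "gain": "rand(-10, 10)",   # resolved later by _generate_random_rule
        },
        "mixup": 2,
    }
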
@@ -116,21 +132,22 @@ def _generate_random_rule(rule: dict, num_ir: int = 0) -> dict:
     from random import randint
 
     out_rule = deepcopy(rule)
-    for key in out_rule:
-        if key == "ir" and out_rule[key] == "rand":
-            # IR is special case
-            if num_ir == 0:
-                out_rule[key] = None
+    for key in ("pre", "post"):
+        for k in out_rule[key]:
+            if k == "ir" and out_rule[key][k] == "rand":
+                # IR is special case
+                if num_ir == 0:
+                    out_rule[key][k] = None
+                else:
+                    out_rule[key][k] = randint(0, num_ir - 1)  # noqa: S311
             else:
-                out_rule[key] = randint(0, num_ir - 1)  # noqa: S311
-        else:
-            out_rule[key] = evaluate_random_rule(str(out_rule[key]))
+                out_rule[key][k] = evaluate_random_rule(str(out_rule[key][k]))
 
-            # convert EQ values from strings to numbers
-            if key in ["eq1", "eq2", "eq3"]:
-                for n in range(3):
-                    if isinstance(out_rule[key][n], str):
-                        out_rule[key][n] = eval(out_rule[key][n])  # noqa: S307
+                # convert EQ values from strings to numbers
+                if k in ("eq1", "eq2", "eq3"):
+                    for n in range(3):
+                        if isinstance(out_rule[key][k][n], str):
+                            out_rule[key][k][n] = eval(out_rule[key][k][n])  # noqa: S307
 
     return out_rule
 
@@ -141,7 +158,7 @@ def _rule_has_rand(rule: dict) -> bool:
     :param rule: Rule
     :return: True if rule contains 'rand'
     """
-    return any("rand" in str(rule[key]) for key in rule)
+    return any("rand" in str(rule[key][k]) for key in rule for k in rule[key])
 
 
 def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr = None, frame_length: int = 1) -> int:
@@ -259,67 +276,165 @@ def _parse_ir(rule: dict, num_ir: int) -> dict:
             raise ValueError(f"Invalid ir entry of {rule_in}")
         return rule_out
 
-    if "ir" not in rule:
-        return rule
+    def _process(rule_in: dict) -> dict:
+        if "ir" not in rule_in:
+            return rule_in
 
-    ir = rule["ir"]
+        ir = rule_in["ir"]
 
-    if ir is None:
-        return rule
+        if ir is None:
+            return rule_in
 
-    if isinstance(ir, str):
-        rule["ir"] = _resolve_str(ir)
-        return rule
+        if isinstance(ir, str):
+            rule_in["ir"] = _resolve_str(ir)
+            return rule_in
 
-    if isinstance(ir, list):
-        rule["ir"] = []
-        for item in ir:
-            result = _resolve_str(item)
-            if isinstance(result, str):
-                rule["ir"].append(_resolve_str(item))
-            else:
-                rule["ir"] += _resolve_str(item)
+        if isinstance(ir, list):
+            rule_in["ir"] = []
+            for item in ir:
+                result = _resolve_str(item)
+                if isinstance(result, str):
+                    rule_in["ir"].append(_resolve_str(item))
+                else:
+                    rule_in["ir"] += _resolve_str(item)
 
-        return rule
+            return rule_in
+
+        if isinstance(ir, int):
+            if ir not in range(num_ir):
+                raise ValueError(f"Invalid ir of {ir}")
+            return rule_in
 
-    if isinstance(ir, int):
-        if ir not in range(num_ir):
-            raise ValueError(f"Invalid ir of {ir}")
-        return rule
+        raise ValueError(f"Invalid ir of {ir}")
+
+    for key in rule:
+        if key in ("pre", "post"):
+            rule[key] = _process(rule[key])
 
-    raise ValueError(f"Invalid ir of {ir}")
+    return rule
 
 
-def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length: int = 1) -> AudioT:
-    """Apply augmentations to audio data
+def apply_augmentation(
+    mixdb: MixtureDatabase,
+    audio: AudioT,
+    augmentation: AugmentationEffects,
+    frame_length: int = 1,
+) -> AudioT:
+    """Apply augmentations to audio data using torchaudio.sox_effects
 
+    :param mixdb: Mixture database
     :param audio: Audio
     :param augmentation: Augmentation
     :param frame_length: Pad resulting audio to be a multiple of this
     :return: Augmented audio
     """
-    from .torchaudio_augmentation import apply_augmentation
+    import numpy as np
+    import torch
+    import torchaudio
+
+    from .audio import read_ir
+    from .constants import SAMPLE_RATE
+
+    effects: list[list[str]] = []
+
+    # TODO: Always normalize and remove normalize from list of available augmentations
+    # Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
+    # TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
+    if augmentation.normalize is not None:
+        effects.append(["norm", str(augmentation.normalize)])
+
+    if augmentation.gain is not None:
+        effects.append(["gain", str(augmentation.gain)])
 
-    return apply_augmentation(audio, augmentation, frame_length)
+    if augmentation.pitch is not None:
+        effects.append(["pitch", str(augmentation.pitch)])
+        effects.append(["rate", str(SAMPLE_RATE)])
+
+    if augmentation.tempo is not None:
+        effects.append(["tempo", "-s", str(augmentation.tempo)])
+
+    if augmentation.eq1 is not None:
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq1]])
+
+    if augmentation.eq2 is not None:
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq2]])
+
+    if augmentation.eq3 is not None:
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq3]])
+
+    if augmentation.lpf is not None:
+        effects.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])
+
+    if effects:
+        if audio.ndim == 1:
+            audio = np.reshape(audio, (1, audio.shape[0]))
+        out = torch.tensor(audio)
+
+        try:
+            out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
+        except Exception as e:
+            raise RuntimeError(f"Error applying {augmentation}: {e}") from e
+
+        audio_out = np.squeeze(np.array(out))
+    else:
+        audio_out = audio
+
+    if augmentation.ir is not None:
+        audio_out = apply_impulse_response(
+            audio=audio_out,
+            ir=read_ir(
+                name=mixdb.impulse_response_file(augmentation.ir),  # pyright: ignore [reportArgumentType]
+                delay=mixdb.impulse_response_delay(augmentation.ir),  # pyright: ignore [reportArgumentType]
+                use_cache=mixdb.use_cache,
+            ),
+        )
+
+    # make sure length is multiple of frame_length
+    return pad_audio_to_frame(audio=audio_out, frame_length=frame_length)
 
 
 def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
-    """Apply impulse response to audio data
+    """Apply impulse response to audio data using scipy
 
     :param audio: Audio
     :param ir: Impulse response data
     :return: Augmented audio
     """
-    from .torchaudio_augmentation import apply_impulse_response
+    import numpy as np
+    from librosa import resample
+    from scipy.signal import fftconvolve
+
+    from .constants import SAMPLE_RATE
+
+    # Early exit if no ir or if all audio is zero
+    if ir is None or not audio.any():
+        return audio
+
+    # Convert audio to IR sample rate
+    audio_in = resample(audio, orig_sr=SAMPLE_RATE, target_sr=ir.sample_rate, res_type="soxr_hq")
+    max_in = np.max(np.abs(audio_in))
+
+    # Apply IR
+    audio_out = fftconvolve(audio_in, ir.data, mode="full")
 
-    return apply_impulse_response(audio, ir)
+    # Delay compensation
+    audio_out = audio_out[ir.delay :]
+
+    # Convert back to global sample rate
+    audio_out = resample(audio_out, orig_sr=ir.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
+
+    # Trim to length
+    audio_out = audio_out[: len(audio)]
+    max_out = np.max(np.abs(audio_out))
+
+    compensation_gain = max_in / max_out
+
+    return audio_out * compensation_gain
 
 
 def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
     from sonusai.utils import dataclass_from_dict
 
-    from .datatypes import Augmentation
-
     processed_rule = rule.to_dict()
     del processed_rule["mixup"]
     processed_rule = _generate_none_rule(processed_rule)
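
The scipy-based apply_impulse_response above convolves at the IR's native rate, drops ir.delay samples, trims to the input length, and rescales so the wet peak matches the dry peak. A standalone sketch of that pattern on synthetic data (not the sonusai API):

    # Standalone sketch: undo a pure 40-sample delay IR with the same
    # convolve / delay-trim / peak-match steps used above.
    import numpy as np
    from scipy.signal import fftconvolve

    audio = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
    ir_data = np.zeros(100)
    ir_data[40] = 0.5  # delayed, attenuated impulse
    delay = 40

    wet = fftconvolve(audio, ir_data, mode="full")[delay:][: len(audio)]
    wet *= np.max(np.abs(audio)) / np.max(np.abs(wet))
    print(np.max(np.abs(wet - audio)))  # ~0: delay and gain are compensated
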
sonusai/mixture/config.py CHANGED
@@ -529,7 +529,7 @@ def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
     return list(
         chain.from_iterable(
             [
-                append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
+                append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", []), 0))
                 for entry in config["impulse_responses"]
             ]
         )
@@ -552,6 +552,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
     from os.path import splitext
 
     from .audio import validate_input_file
+    from .ir_delay import get_impulse_response_delay
     from .tokenized_shell_vars import tokenized_expand
     from .tokenized_shell_vars import tokenized_replace
 
@@ -572,7 +573,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
             for file in listdir(name):
                 if not isabs(file):
                     file = join(dir_name, file)
-                child = ImpulseResponseFile(file, entry.tags)
+                child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
                 impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
         else:
             try:
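
For reference, a sketch of building an impulse response entry with its delay; the path is a placeholder, and get_impulse_response_delay is the helper imported above from the new ir_delay module:

    # Sketch only: the path is a placeholder.
    from sonusai.mixture.datatypes import ImpulseResponseFile
    from sonusai.mixture.ir_delay import get_impulse_response_delay

    path = "irs/small_room.wav"
    ir_file = ImpulseResponseFile(path, ["room"], get_impulse_response_delay(path))
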
@@ -587,7 +588,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
                                 tokens.update(new_tokens)
                                 if not isabs(file):
                                     file = join(dir_name, file)
-                                child = ImpulseResponseFile(file, entry.tags)
+                                child = ImpulseResponseFile(file, entry.tags, get_impulse_response_delay(file))
                                 impulse_response_files.extend(append_impulse_response_files(entry=child, tokens=tokens))
                 elif ext == ".yml":
                     try:
@@ -602,7 +603,11 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
                         raise OSError(f"Error processing {name}: {e}") from e
                 else:
                     validate_input_file(name)
-                    impulse_response_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags))
+                    impulse_response_files.append(
+                        ImpulseResponseFile(
+                            tokenized_replace(name, tokens), entry.tags, get_impulse_response_delay(name)
+                        )
+                    )
         except Exception as e:
             raise OSError(f"Error processing {name}: {e}") from e
 
@@ -38,7 +38,6 @@ VALID_AUGMENTATIONS = [
     "eq3",
     "lpf",
     "ir",
-    "mixup",
 ]
 VALID_NOISE_MIX_MODES = ["exhaustive", "non-exhaustive", "non-combinatorial"]
 RAND_PATTERN = re.compile(r"rand\(([-+]?(\d+(\.\d*)?|\.\d+)),\s*([-+]?(\d+(\.\d*)?|\.\d+))\)")
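
As a quick illustration of RAND_PATTERN, group 1 captures the lower bound and group 4 the upper bound of a "rand(lo, hi)" specification:

    import re

    RAND_PATTERN = re.compile(r"rand\(([-+]?(\d+(\.\d*)?|\.\d+)),\s*([-+]?(\d+(\.\d*)?|\.\d+))\)")
    m = RAND_PATTERN.search("rand(-3.5, 3.5)")
    print(m.group(1), m.group(4))  # -3.5 3.5
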
@@ -75,7 +75,7 @@ EQ: TypeAlias = tuple[float | int, float | int, float | int]
 
 
 @dataclass
-class AugmentationRule(DataClassSonusAIMixin):
+class AugmentationRuleEffects(DataClassSonusAIMixin):
     normalize: OptionalNumberStr = None
     pitch: OptionalNumberStr = None
     tempo: OptionalNumberStr = None
@@ -85,11 +85,17 @@ class AugmentationRule(DataClassSonusAIMixin):
     eq3: OptionalListNumberStr = None
     lpf: OptionalNumberStr = None
     ir: OptionalNumberStr = None
+
+
+@dataclass
+class AugmentationRule(DataClassSonusAIMixin):
+    pre: AugmentationRuleEffects
+    post: AugmentationRuleEffects | None = None
     mixup: int = 1
 
 
 @dataclass
-class Augmentation(DataClassSonusAIMixin):
+class AugmentationEffects(DataClassSonusAIMixin):
     normalize: float | None = None
     pitch: float | None = None
     tempo: float | None = None
@@ -101,6 +107,12 @@ class Augmentation(DataClassSonusAIMixin):
     ir: int | None = None
 
 
+@dataclass
+class Augmentation(DataClassSonusAIMixin):
+    pre: AugmentationEffects
+    post: AugmentationEffects
+
+
 @dataclass(frozen=True)
 class UniversalSNRGenerator:
     is_random: bool
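
An illustrative construction of the split dataclasses (field values are arbitrary; normally these objects come from config rules via augmentation_from_rule):

    # Illustrative only; field values are arbitrary.
    from sonusai.mixture.datatypes import Augmentation, AugmentationEffects
    from sonusai.mixture.datatypes import AugmentationRule, AugmentationRuleEffects

    rule = AugmentationRule(pre=AugmentationRuleEffects(tempo="rand(0.8, 1.2)"), mixup=2)
    applied = Augmentation(pre=AugmentationEffects(tempo=1.05), post=AugmentationEffects())
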
@@ -191,19 +203,16 @@ class GenFTData:
 
 @dataclass
 class ImpulseResponseData:
-    name: str
-    sample_rate: int
     data: AudioT
-
-    @property
-    def length(self) -> int:
-        return len(self.data)
+    sample_rate: int
+    delay: int
 
 
 @dataclass
 class ImpulseResponseFile:
     file: str
     tags: list[str]
+    delay: int
 
 
 @dataclass(frozen=True)
@@ -230,9 +239,9 @@ class Target(DataClassSonusAIMixin):
     def gain(self) -> float:
         # gain is used to back out the gain augmentation in order to return the target audio
         # to its normalized level when calculating truth (if needed).
-        if self.augmentation.gain is None:
+        if self.augmentation.pre.gain is None:
             return 1.0
-        return round(10 ** (self.augmentation.gain / 20), ndigits=5)
+        return round(10 ** (self.augmentation.pre.gain / 20), ndigits=5)
 
 
 Targets: TypeAlias = list[Target]
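
As a worked example of the gain back-out above, a pre gain of -6 dB maps to a linear factor of about 0.501:

    print(round(10 ** (-6 / 20), ndigits=5))  # 0.50119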