sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -1,444 +0,0 @@
1
- from sonusai.mixture.datatypes import AudioT
2
- from sonusai.mixture.datatypes import Augmentation
3
- from sonusai.mixture.datatypes import AugmentationEffects
4
- from sonusai.mixture.datatypes import AugmentationRule
5
- from sonusai.mixture.datatypes import ImpulseResponseData
6
- from sonusai.mixture.datatypes import OptionalNumberStr
7
- from sonusai.mixture.mixdb import MixtureDatabase
8
-
9
-
10
- def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> list[AugmentationRule]:
11
- """Generate augmentation rules from list of input rules
12
-
13
- :param rules: Dictionary of augmentation config rule[s]
14
- :param num_ir: Number of impulse responses in config
15
- :return: List of augmentation rules
16
- """
17
- from sonusai.utils import dataclass_from_dict
18
-
19
- from .datatypes import AugmentationRule
20
-
21
- processed_rules: list[dict] = []
22
- if not isinstance(rules, list):
23
- rules = [rules]
24
-
25
- for rule in rules:
26
- rule = _parse_ir(rule, num_ir)
27
- processed_rules = _expand_rules(expanded_rules=processed_rules, rule=rule)
28
-
29
- return [dataclass_from_dict(AugmentationRule, processed_rule) for processed_rule in processed_rules] # pyright: ignore [reportReturnType]
30
-
31
-
32
- def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
33
- """Expand rules
34
-
35
- :param expanded_rules: Working list of expanded rules
36
- :param rule: Rule to process
37
- :return: List of expanded rules
38
- """
39
- from copy import deepcopy
40
-
41
- from sonusai.utils import convert_string_to_number
42
-
43
- from .constants import VALID_AUGMENTATIONS
44
- from .eq_rule_is_valid import eq_rule_is_valid
45
-
46
- if "pre" not in rule:
47
- raise ValueError("Rule must have 'pre' key")
48
-
49
- if "post" not in rule:
50
- rule["post"] = {}
51
-
52
- for key in rule:
53
- if key not in ("pre", "post", "mixup"):
54
- raise ValueError(f"Invalid augmentation key: {key}")
55
-
56
- if key in ("pre", "post"):
57
- for k, v in list(rule[key].items()):
58
- if v is None:
59
- del rule[key][k]
60
-
61
- # replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility
62
- for key in rule:
63
- rule[key] = {"eq1" if k == "eq" else k: v for k, v in rule[key].items()}
64
-
65
- for key in ("pre", "post"):
66
- for k in rule[key]:
67
- if k not in VALID_AUGMENTATIONS:
68
- nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
69
- raise ValueError(f"Invalid augmentation: {k}.\nValid augmentations are:\n{nice_list}")
70
-
71
- if k in ["eq1", "eq2", "eq3"]:
72
- if not eq_rule_is_valid(rule[key][k]):
73
- raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
74
-
75
- if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key][k]):
76
- # Expand multiple rules
77
- for value in rule[key][k]:
78
- expanded_rule = deepcopy(rule)
79
- if isinstance(value, str) and value == "none":
80
- expanded_rule[key][k] = None
81
- else:
82
- expanded_rule[key][k] = deepcopy(value)
83
- _expand_rules(expanded_rules, expanded_rule)
84
- return expanded_rules
85
-
86
- else:
87
- if isinstance(rule[key][k], list):
88
- for value in rule[key][k]:
89
- if isinstance(value, list):
90
- raise TypeError(f"Invalid augmentation value for {k}: {rule[key][k]}")
91
- expanded_rule = deepcopy(rule)
92
- expanded_rule[key][k] = deepcopy(value)
93
- _expand_rules(expanded_rules, expanded_rule)
94
- return expanded_rules
95
- else:
96
- rule[key][k] = convert_string_to_number(rule[key][k])
97
- if not (
98
- isinstance(rule[key][k], float | int)
99
- or rule[key][k].startswith("rand")
100
- or rule[key][k] == "none"
101
- ):
102
- raise ValueError(f"Invalid augmentation value for {k}: {rule[key][k]}")
103
-
104
- expanded_rules.append(rule)
105
- return expanded_rules
106
-
107
-
108
- def _generate_none_rule(rule: dict) -> dict:
109
- """Generate a new rule from a rule that contains 'none' directives
110
-
111
- :param rule: Rule
112
- :return: New rule
113
- """
114
- from copy import deepcopy
115
-
116
- out_rule = deepcopy(rule)
117
- for key in out_rule:
118
- if out_rule[key] == "none":
119
- out_rule[key] = None
120
-
121
- return out_rule
122
-
123
-
124
- def _generate_random_rule(rule: dict, num_ir: int = 0) -> dict:
125
- """Generate a new rule from a rule that contains 'rand' directives
126
-
127
- :param rule: Rule
128
- :param num_ir: Number of impulse responses in config
129
- :return: Randomized rule
130
- """
131
- from copy import deepcopy
132
- from random import randint
133
-
134
- out_rule = deepcopy(rule)
135
- for key in ("pre", "post"):
136
- for k in out_rule[key]:
137
- if k == "ir" and out_rule[key][k] == "rand":
138
- # IR is special case
139
- if num_ir == 0:
140
- out_rule[key][k] = None
141
- else:
142
- out_rule[key][k] = randint(0, num_ir - 1) # noqa: S311
143
- else:
144
- out_rule[key][k] = evaluate_random_rule(str(out_rule[key][k]))
145
-
146
- # convert EQ values from strings to numbers
147
- if k in ("eq1", "eq2", "eq3"):
148
- for n in range(3):
149
- if isinstance(out_rule[key][k][n], str):
150
- out_rule[key][k][n] = eval(out_rule[key][k][n]) # noqa: S307
151
-
152
- return out_rule
153
-
154
-
155
- def _rule_has_rand(rule: dict) -> bool:
156
- """Determine if any keys in the given rule contain 'rand'
157
-
158
- :param rule: Rule
159
- :return: True if rule contains 'rand'
160
- """
161
- return any("rand" in str(rule[key][k]) for key in rule for k in rule[key])
162
-
163
-
164
- def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr = None, frame_length: int = 1) -> int:
165
- """Estimate the length of audio after augmentation
166
-
167
- :param length: Number of samples in audio
168
- :param tempo: Tempo rule
169
- :param frame_length: Pad resulting audio to be a multiple of this
170
- :return: Estimated length of augmented audio
171
- """
172
- import numpy as np
173
-
174
- if tempo is not None:
175
- length = int(np.round(length / float(tempo)))
176
-
177
- length = _get_padded_length(length, frame_length)
178
-
179
- return length
180
-
181
-
182
- def get_mixups(augmentations: list[AugmentationRule]) -> list[int]:
183
- """Get a list of mixup values used
184
-
185
- :param augmentations: List of augmentations
186
- :return: List of mixup values used
187
- """
188
- return sorted({augmentation.mixup for augmentation in augmentations})
189
-
190
-
191
- def get_augmentation_indices_for_mixup(augmentations: list[AugmentationRule], mixup: int) -> list[int]:
192
- """Get a list of augmentation indices for a given mixup value
193
-
194
- :param augmentations: List of augmentations
195
- :param mixup: Mixup value of interest
196
- :return: List of augmentation indices
197
- """
198
- indices = []
199
- for idx, augmentation in enumerate(augmentations):
200
- if mixup == augmentation.mixup:
201
- indices.append(idx)
202
-
203
- return indices
204
-
205
-
206
- def pad_audio_to_frame(audio: AudioT, frame_length: int = 1) -> AudioT:
207
- """Pad audio to be a multiple of frame length
208
-
209
- :param audio: Audio
210
- :param frame_length: Pad resulting audio to be a multiple of this
211
- :return: Padded audio
212
- """
213
- return pad_audio_to_length(audio, _get_padded_length(len(audio), frame_length))
214
-
215
-
216
- def _get_padded_length(length: int, frame_length: int) -> int:
217
- """Get the number of pad samples needed
218
-
219
- :param length: Length of audio
220
- :param frame_length: Desired length will be a multiple of this
221
- :return: Padded length
222
- """
223
- mod = int(length % frame_length)
224
- pad_length = frame_length - mod if mod else 0
225
- return length + pad_length
226
-
227
-
228
- def pad_audio_to_length(audio: AudioT, length: int) -> AudioT:
229
- """Pad audio to given length
230
-
231
- :param audio: Audio
232
- :param length: Length of output
233
- :return: Padded audio
234
- """
235
- import numpy as np
236
-
237
- return np.pad(array=audio, pad_width=(0, length - len(audio)))
238
-
239
-
240
- def apply_gain(audio: AudioT, gain: float) -> AudioT:
241
- """Apply gain to audio
242
-
243
- :param audio: Audio
244
- :param gain: Amount of gain
245
- :return: Adjusted audio
246
- """
247
- return audio * gain
248
-
249
-
250
- def evaluate_random_rule(rule: str) -> str | float:
251
- """Evaluate 'rand' directive
252
-
253
- :param rule: Rule
254
- :return: Resolved value
255
- """
256
- import re
257
- from random import uniform
258
-
259
- from .constants import RAND_PATTERN
260
-
261
- def rand_repl(m):
262
- return f"{uniform(float(m.group(1)), float(m.group(4))):.2f}" # noqa: S311
263
-
264
- return eval(re.sub(RAND_PATTERN, rand_repl, rule)) # noqa: S307
265
-
266
-
267
- def _parse_ir(rule: dict, num_ir: int) -> dict:
268
- from .helpers import generic_ids_to_list
269
-
270
- def _resolve_str(rule_in: str) -> str | list[int]:
271
- if rule_in in ["rand", "none"]:
272
- return rule_in
273
-
274
- rule_out = generic_ids_to_list(num_ir, rule_in)
275
- if not all(ro in range(num_ir) for ro in rule_out):
276
- raise ValueError(f"Invalid ir entry of {rule_in}")
277
- return rule_out
278
-
279
- def _process(rule_in: dict) -> dict:
280
- if "ir" not in rule_in:
281
- return rule_in
282
-
283
- ir = rule_in["ir"]
284
-
285
- if ir is None:
286
- return rule_in
287
-
288
- if isinstance(ir, str):
289
- rule_in["ir"] = _resolve_str(ir)
290
- return rule_in
291
-
292
- if isinstance(ir, list):
293
- rule_in["ir"] = []
294
- for item in ir:
295
- result = _resolve_str(item)
296
- if isinstance(result, str):
297
- rule_in["ir"].append(_resolve_str(item))
298
- else:
299
- rule_in["ir"] += _resolve_str(item)
300
-
301
- return rule_in
302
-
303
- if isinstance(ir, int):
304
- if ir not in range(num_ir):
305
- raise ValueError(f"Invalid ir of {ir}")
306
- return rule_in
307
-
308
- raise ValueError(f"Invalid ir of {ir}")
309
-
310
- for key in rule:
311
- if key in ("pre", "post"):
312
- rule[key] = _process(rule[key])
313
-
314
- return rule
315
-
316
-
317
- def apply_augmentation(
318
- mixdb: MixtureDatabase,
319
- audio: AudioT,
320
- augmentation: AugmentationEffects,
321
- frame_length: int = 1,
322
- ) -> AudioT:
323
- """Apply augmentations to audio data using torchaudio.sox_effects
324
-
325
- :param mixdb: Mixture database
326
- :param audio: Audio
327
- :param augmentation: Augmentation
328
- :param frame_length: Pad resulting audio to be a multiple of this
329
- :return: Augmented audio
330
- """
331
- import numpy as np
332
- import torch
333
- import torchaudio
334
-
335
- from .audio import read_ir
336
- from .constants import SAMPLE_RATE
337
-
338
- effects: list[list[str]] = []
339
-
340
- # TODO: Always normalize and remove normalize from list of available augmentations
341
- # Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
342
- # TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
343
- if augmentation.normalize is not None:
344
- effects.append(["norm", str(augmentation.normalize)])
345
-
346
- if augmentation.gain is not None:
347
- effects.append(["gain", str(augmentation.gain)])
348
-
349
- if augmentation.pitch is not None:
350
- effects.append(["pitch", str(augmentation.pitch)])
351
- effects.append(["rate", str(SAMPLE_RATE)])
352
-
353
- if augmentation.tempo is not None:
354
- effects.append(["tempo", "-s", str(augmentation.tempo)])
355
-
356
- if augmentation.eq1 is not None:
357
- effects.append(["equalizer", *[str(item) for item in augmentation.eq1]])
358
-
359
- if augmentation.eq2 is not None:
360
- effects.append(["equalizer", *[str(item) for item in augmentation.eq2]])
361
-
362
- if augmentation.eq3 is not None:
363
- effects.append(["equalizer", *[str(item) for item in augmentation.eq3]])
364
-
365
- if augmentation.lpf is not None:
366
- effects.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])
367
-
368
- if effects:
369
- if audio.ndim == 1:
370
- audio = np.reshape(audio, (1, audio.shape[0]))
371
- out = torch.tensor(audio)
372
-
373
- try:
374
- out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
375
- except Exception as e:
376
- raise RuntimeError(f"Error applying {augmentation}: {e}") from e
377
-
378
- audio_out = np.squeeze(np.array(out))
379
- else:
380
- audio_out = audio
381
-
382
- if augmentation.ir is not None:
383
- audio_out = apply_impulse_response(
384
- audio=audio_out,
385
- ir=read_ir(
386
- name=mixdb.impulse_response_file(augmentation.ir), # pyright: ignore [reportArgumentType]
387
- delay=mixdb.impulse_response_delay(augmentation.ir), # pyright: ignore [reportArgumentType]
388
- use_cache=mixdb.use_cache,
389
- ),
390
- )
391
-
392
- # make sure length is multiple of frame_length
393
- return pad_audio_to_frame(audio=audio_out, frame_length=frame_length)
394
-
395
-
396
- def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
397
- """Apply impulse response to audio data using scipy
398
-
399
- :param audio: Audio
400
- :param ir: Impulse response data
401
- :return: Augmented audio
402
- """
403
- import numpy as np
404
- from librosa import resample
405
- from scipy.signal import fftconvolve
406
-
407
- from .constants import SAMPLE_RATE
408
-
409
- # Early exit if no ir or if all audio is zero
410
- if ir is None or not audio.any():
411
- return audio
412
-
413
- # Convert audio to IR sample rate
414
- audio_in = resample(audio, orig_sr=SAMPLE_RATE, target_sr=ir.sample_rate, res_type="soxr_hq")
415
- max_in = np.max(np.abs(audio_in))
416
-
417
- # Apply IR
418
- audio_out = fftconvolve(audio_in, ir.data, mode="full")
419
-
420
- # Delay compensation
421
- audio_out = audio_out[ir.delay :]
422
-
423
- # Convert back to global sample rate
424
- audio_out = resample(audio_out, orig_sr=ir.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_hq")
425
-
426
- # Trim to length
427
- audio_out = audio_out[: len(audio)]
428
- max_out = np.max(np.abs(audio_out))
429
-
430
- compensation_gain = max_in / max_out
431
-
432
- return audio_out * compensation_gain
433
-
434
-
435
- def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
436
- from sonusai.utils import dataclass_from_dict
437
-
438
- processed_rule = rule.to_dict()
439
- del processed_rule["mixup"]
440
- processed_rule = _generate_none_rule(processed_rule)
441
- if _rule_has_rand(processed_rule):
442
- processed_rule = _generate_random_rule(processed_rule, num_ir)
443
-
444
- return dataclass_from_dict(Augmentation, processed_rule) # pyright: ignore [reportReturnType]
@@ -1,15 +0,0 @@
1
- from sonusai.mixture.datatypes import ClassCount
2
- from sonusai.mixture.datatypes import GeneralizedIDs
3
- from sonusai.mixture.mixdb import MixtureDatabase
4
-
5
-
6
- def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = "*") -> ClassCount:
7
- """Sums the class counts for given mixids"""
8
- total_class_count = [0] * mixdb.num_classes
9
- m_ids = mixdb.mixids_to_list(mixids)
10
- for m_id in m_ids:
11
- class_count = mixdb.mixture_class_count(m_id)
12
- for cl in range(mixdb.num_classes):
13
- total_class_count[cl] += class_count[cl]
14
-
15
- return total_class_count
@@ -1,45 +0,0 @@
1
- from typing import Any
2
-
3
-
4
- def eq_rule_is_valid(rule: Any) -> bool:
5
- """Check if EQ rule is valid
6
-
7
- An EQ rule must be a tuple of length 3 or a list of length 3 tuples.
8
- """
9
-
10
- # Must be a list or string equal to 'none'
11
- if isinstance(rule, str) and rule == "none":
12
- return True
13
-
14
- if not isinstance(rule, list):
15
- return False
16
-
17
- if len(rule) != 3:
18
- # If the length is not 3, then all elements must also be lists
19
- if not all(_check_for_none(el) for el in rule):
20
- return False
21
- rules = rule
22
- else:
23
- rules = [rule]
24
-
25
- for r in rules:
26
- # Each item must be a number or string
27
- if not all(isinstance(el, float | int | str) for el in r):
28
- return False
29
-
30
- if isinstance(r, str) and r == "none":
31
- continue
32
-
33
- for el in r:
34
- # If a string, item must start with 'rand'
35
- if isinstance(el, str) and not el.startswith("rand"):
36
- return False
37
-
38
- return True
39
-
40
-
41
- def _check_for_none(rule: Any) -> bool:
42
- """Check if EQ rule is 'none'"""
43
- if isinstance(rule, str) and rule == "none":
44
- return True
45
- return bool(isinstance(rule, list) and len(rule) == 3)
@@ -1,107 +0,0 @@
1
- from sonusai.mixture.datatypes import AugmentationRule
2
- from sonusai.mixture.datatypes import AugmentedTarget
3
- from sonusai.mixture.datatypes import TargetFile
4
-
5
-
6
- def balance_targets(
7
- augmented_targets: list[AugmentedTarget],
8
- targets: list[TargetFile],
9
- target_augmentations: list[AugmentationRule],
10
- class_balancing_augmentation: AugmentationRule,
11
- num_classes: int,
12
- num_ir: int,
13
- mixups: list[int] | None = None,
14
- ) -> tuple[list[AugmentedTarget], list[AugmentationRule]]:
15
- import math
16
-
17
- from .augmentation import get_mixups
18
- from .datatypes import AugmentedTarget
19
- from .targets import get_augmented_target_ids_by_class
20
-
21
- first_cba_id = len(target_augmentations)
22
-
23
- if mixups is None:
24
- mixups = get_mixups(target_augmentations)
25
-
26
- for mixup in mixups:
27
- if mixup == 1:
28
- continue
29
-
30
- augmented_target_indices_by_class = get_augmented_target_ids_by_class(
31
- augmented_targets=augmented_targets,
32
- targets=targets,
33
- target_augmentations=target_augmentations,
34
- mixup=mixup,
35
- num_classes=num_classes,
36
- )
37
-
38
- largest = max([len(item) for item in augmented_target_indices_by_class])
39
- largest = math.ceil(largest / mixup) * mixup
40
- for at_indices in augmented_target_indices_by_class:
41
- additional_augmentations_needed = largest - len(at_indices)
42
- target_ids = sorted({augmented_targets[at_index].target_id for at_index in at_indices})
43
-
44
- tfi_idx = 0
45
- for _ in range(additional_augmentations_needed):
46
- target_id = target_ids[tfi_idx]
47
- tfi_idx = (tfi_idx + 1) % len(target_ids)
48
- augmentation_index, target_augmentations = _get_unused_balancing_augmentation(
49
- augmented_targets=augmented_targets,
50
- targets=targets,
51
- target_augmentations=target_augmentations,
52
- class_balancing_augmentation=class_balancing_augmentation,
53
- target_id=target_id,
54
- mixup=mixup,
55
- num_ir=num_ir,
56
- first_cba_id=first_cba_id,
57
- )
58
- augmented_target = AugmentedTarget(target_id=target_id, target_augmentation_id=augmentation_index)
59
- augmented_targets.append(augmented_target)
60
-
61
- return augmented_targets, target_augmentations
62
-
63
-
64
- def _get_unused_balancing_augmentation(
65
- augmented_targets: list[AugmentedTarget],
66
- targets: list[TargetFile],
67
- target_augmentations: list[AugmentationRule],
68
- class_balancing_augmentation: AugmentationRule,
69
- target_id: int,
70
- mixup: int,
71
- num_ir: int,
72
- first_cba_id: int,
73
- ) -> tuple[int, list[AugmentationRule]]:
74
- """Get an unused balancing augmentation for a given target file index"""
75
- from dataclasses import asdict
76
-
77
- from .augmentation import get_augmentation_rules
78
-
79
- balancing_augmentations = [item for item in range(len(target_augmentations)) if item >= first_cba_id]
80
- used_balancing_augmentations = [
81
- at.target_augmentation_id
82
- for at in augmented_targets
83
- if at.target_id == target_id and at.target_augmentation_id in balancing_augmentations
84
- ]
85
-
86
- augmentation_indices = [
87
- item
88
- for item in balancing_augmentations
89
- if item not in used_balancing_augmentations and target_augmentations[item].mixup == mixup
90
- ]
91
- if len(augmentation_indices) > 0:
92
- return augmentation_indices[0], target_augmentations
93
-
94
- class_balancing_augmentation = get_class_balancing_augmentation(
95
- target=targets[target_id], default_cba=class_balancing_augmentation
96
- )
97
- new_augmentation = get_augmentation_rules(rules=asdict(class_balancing_augmentation), num_ir=num_ir)[0]
98
- new_augmentation.mixup = mixup
99
- target_augmentations.append(new_augmentation)
100
- return len(target_augmentations) - 1, target_augmentations
101
-
102
-
103
- def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule:
104
- """Get the class balancing augmentation rule for the given target"""
105
- if target.class_balancing_augmentation is not None:
106
- return target.class_balancing_augmentation
107
- return default_cba