sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,386 @@
1
+ # sai_rand
2
+ # sai_rand_choice
3
+ # sai_rand_choice_nr
4
+ # sai_sequence
5
+ # sai_expand
6
+
7
+ from ..datatypes import AudioT
8
+ from ..datatypes import Effects
9
+ from .mixdb import MixtureDatabase
10
+
11
+
12
def get_effect_rules(location: str, config: dict, test: bool = False) -> dict[str, list[Effects]]:
    """Build the validated effect rules for every source category in a configuration.

    For each category in ``config["sources"]``, every entry of its ``effects`` list has
    its ``ir`` effects normalized, is fanned out over any ``sai_expand(...)`` directives,
    and the results are converted to ``Effects`` dataclasses. All rules are then
    validated against the mixture database.

    :param location: Mixture database location
    :param config: Configuration dict containing a ``sources`` mapping
    :param test: Passed through to ``MixtureDatabase``
    :return: Mapping of source category to its list of effect rules
    """
    from ..datatypes import Effects
    from ..utils.dataclass_from_dict import list_dataclass_from_dict
    from .mixdb import MixtureDatabase

    mixdb = MixtureDatabase(location, test)

    def _category_rules(effect_rules: list) -> list[Effects]:
        # Normalize IR entries first, then expand sai_expand directives
        expanded: list[dict] = []
        for effect_rule in effect_rules:
            normalized = _parse_ir_rule(effect_rule, mixdb.num_ir_files)
            expanded = _expand_effect_rules(expanded, normalized)
        return list_dataclass_from_dict(list[Effects], expanded)

    rules: dict[str, list[Effects]] = {
        category: _category_rules(source["effects"]) for category, source in config["sources"].items()
    }

    validate_rules(mixdb, rules)
    return rules
29
+
30
+
31
def sai_expand(text: str) -> list[str]:
    """Expand ``sai_expand(a, b, ...)`` directives in *text* into every combination.

    The first directive is replaced, one value at a time, by each of its
    comma-separated values (all spaces removed). Each resulting string is expanded
    again, so several directives in one string yield their full cartesian product
    in order.

    :param text: Text that may contain ``sai_expand(...)`` directives
    :return: Fully expanded strings; ``[text]`` when no directive is present
    """
    import re

    directive = re.compile(r"sai_expand\((.+?)\)")

    match = directive.search(text)
    if match is None:
        return [text]

    options = match.group(1).replace(" ", "").split(",")
    return [
        expansion
        for option in options
        # Substitute only the first directive; recursion handles the rest
        for expansion in sai_expand(directive.sub(option, text, count=1))
    ]
60
+
61
+
62
def _expand_effect_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
    """Append every ``sai_expand`` expansion of *rule* to *expanded_rules*.

    The ``pre`` list and then the ``post`` list are scanned for the first entry
    containing a ``sai_expand(...)`` directive. For each expansion of that entry, a
    deep copy of *rule* with the entry substituted is expanded recursively, so the
    remaining directives are expanded too. A rule without directives is appended
    unchanged.

    :param expanded_rules: Accumulator list; it is mutated in place
    :param rule: Effect rule dict with optional ``pre``/``post`` lists of strings
    :return: *expanded_rules* (the same list object)
    """
    from copy import deepcopy

    for key in ("pre", "post"):
        if key not in rule:
            continue

        for idx, entry in enumerate(rule[key]):
            expansions = sai_expand(entry)
            # Compare with the unexpanded entry rather than counting results: a
            # single-value directive such as "sai_expand(3)" yields exactly one
            # result that still differs from the original text and must be used.
            if expansions == [entry]:
                continue

            for expansion in expansions:
                expanded_rule = deepcopy(rule)
                expanded_rule[key][idx] = expansion
                _expand_effect_rules(expanded_rules, expanded_rule)
            return expanded_rules

    expanded_rules.append(rule)
    return expanded_rules
81
+
82
+
83
def _parse_ir_rule(rule: dict, num_ir: int) -> dict:
    """Normalize the ``ir`` entries in the ``pre``/``post`` effect lists of a rule.

    - ``ir <n>`` (numeric) is kept as-is after checking ``0 <= n < num_ir``.
    - ``ir sai_...`` is kept unchanged for later random resolution.
    - Any other ``ir <ids>`` is resolved with ``generic_ids_to_list``. A single id
      becomes ``ir <id>``; several become ``ir sai_expand(<id>, <id>, ...)``.
    - A bare ``ir`` with no argument is dropped.

    Non-``ir`` entries, including empty strings, pass through unchanged.

    :param rule: Effect rule dict with optional ``pre``/``post`` lists; not modified
    :param num_ir: Number of impulse response files in the database
    :return: New rule dict containing the normalized effect lists
    :raises ValueError: If an IR index is out of range
    """
    from ..datatypes import EffectList
    from .helpers import generic_ids_to_list

    def _resolve_str(parameters: str) -> str:
        # Random directives are resolved later (see evaluate_sai_random_ir)
        if parameters.startswith("sai_"):
            return f"ir {parameters}"

        irs = generic_ids_to_list(num_ir, parameters)

        if not all(ro in range(num_ir) for ro in irs):
            raise ValueError(f"Invalid ir of {parameters}")

        if len(irs) == 1:
            return f"ir {irs[0]}"
        return f"ir sai_expand({', '.join(map(str, irs))})"

    def _process(rules_in: EffectList) -> EffectList:
        rules_out: EffectList = []

        for rule_in in rules_in:
            parts = rule_in.split(maxsplit=1)

            # Not an IR effect (an empty entry has no parts at all): pass through
            if not parts or parts[0] != "ir":
                rules_out.append(rule_in)
                continue

            # NOTE(review): a bare "ir" with no argument is silently dropped
            if len(parts) == 1:
                continue

            parameters = parts[1]
            if parameters.isnumeric():
                if int(parameters) not in range(num_ir):
                    raise ValueError(f"Invalid ir of {parameters}")
                rules_out.append(rule_in)
                continue

            rules_out.append(_resolve_str(parameters))

        return rules_out

    # Work on a shallow copy so the caller's dict (typically loaded config) is not
    # modified; _process always builds new lists, so the original lists are untouched.
    parsed = dict(rule)
    for key in ("pre", "post"):
        if key in parsed:
            parsed[key] = _process(parsed[key])

    return parsed
135
+
136
+
137
def apply_effects(
    mixdb: MixtureDatabase,
    audio: AudioT,
    effects: Effects,
    pre: bool = True,
    post: bool = True,
) -> AudioT:
    """Run the pre- and/or post-truth effect chains over a copy of the audio.

    Consecutive non-IR entries are batched into one ``apply_sox_effects`` call. Each
    ``ir <index>`` entry first flushes the pending batch, then applies that impulse
    response (read from the database with its delay and cache setting).

    :param mixdb: Mixture database supplying IR files, delays, and cache setting
    :param audio: Input audio (not modified; a copy is processed)
    :param effects: Effects whose ``pre``/``post`` chains are applied
    :param pre: Apply pre-truth effects
    :param post: Apply post-truth effects
    :return: Effected audio
    """
    from ..datatypes import EffectList
    from .ir_effects import apply_ir
    from .ir_effects import read_ir
    from .sox_effects import apply_sox_effects

    def _run_chain(signal: AudioT, chain) -> AudioT:
        pending: EffectList = []
        for step in chain:
            if not step.startswith("ir "):
                pending.append(step)
                continue

            # Flush the batched sox effects before convolving with the IR
            signal = apply_sox_effects(signal, pending)
            pending = []

            ir_index = int(step.split()[1])
            impulse = read_ir(
                name=mixdb.ir_file(ir_index),
                delay=mixdb.ir_delay(ir_index),
                use_cache=mixdb.use_cache,
            )
            signal = apply_ir(audio=signal, ir=impulse)

        # Apply whatever sox effects remain after the last IR (possibly none)
        return apply_sox_effects(signal, pending)

    result = audio.copy()

    if pre:
        result = _run_chain(result, effects.pre)

    if post:
        result = _run_chain(result, effects.post)

    return result
192
+
193
+
194
def estimate_effected_length(
    samples: int,
    effects: Effects,
    frame_length: int = 1,
    pre: bool = True,
    post: bool = True,
) -> int:
    """Estimate effected audio length

    Only ``speed`` and ``tempo`` effects change the estimate; all other effects
    are assumed to preserve length. The result is passed through
    ``get_padded_length`` so that it is a multiple of *frame_length*.

    :param samples: Original length in samples
    :param effects: Effects
    :param frame_length: Length will be a multiple of this
    :param pre: Apply pre-truth effects
    :param post: Apply post-truth effects
    :return: Estimated length in samples
    :raises ValueError: If a speed/tempo factor is not positive
    """
    import re

    from .pad_audio import get_padded_length

    # Unsigned decimal number: "1", "1.5", "1." or ".5" (a single decimal point only)
    number = r"(?:\d+(?:\.\d*)?|\.\d+)"
    # speed factor[c] -- with the "c" suffix the factor is given in cents
    speed_pattern = re.compile(rf"^speed\s+(-?{number})(c?)$")
    # tempo [-q] [-m|-s|-l] factor [segment [search [overlap]]]
    tempo_pattern = re.compile(rf"^tempo\s+(?:-q\s+)?(?:-[msl]\s+)?({number})(?=\s|$)")

    def _scale(length: int, factor: float, effect: str) -> int:
        if factor <= 0:
            raise ValueError(f"Invalid effect: '{effect}'. Factor must be positive.")
        return int(length / factor + 0.5)

    def _update_samples(length: int, effect: str) -> int:
        result = speed_pattern.search(effect)
        if result:
            factor = float(result.group(1))
            if result.group(2):
                # Convert cents to a speed ratio
                factor = 2 ** (factor / 1200)
            return _scale(length, factor, effect)

        result = tempo_pattern.search(effect)
        if result:
            return _scale(length, float(result.group(1)), effect)

        # Other effects do not affect length
        return length

    length = samples

    if pre:
        for effect in effects.pre:
            length = _update_samples(length, effect)

    if post:
        for effect in effects.post:
            length = _update_samples(length, effect)

    return get_padded_length(length, frame_length)
245
+
246
+
247
def evaluate_sai_random_float(text: str) -> str:
    """Replace every ``sai_rand(min, max)`` directive with a uniform random value.

    Each directive becomes a float drawn uniformly between *min* and *max*, formatted
    with two decimals. Substitution is repeated so that nested directives such as
    ``sai_rand(sai_rand(1, 2), 3)`` resolve from the inside out.

    :param text: Text to evaluate
    :return: Resolved rule
    :raises ValueError: If substitution fails or is still finding directives after 100 passes
    """
    import re
    from random import uniform

    directive = re.compile(r"sai_rand\(([-+]?(\d+(\.\d*)?|\.\d+)),\s*([-+]?(\d+(\.\d*)?|\.\d+))\)")

    def _draw(match) -> str:
        low = float(match.group(1))
        high = float(match.group(4))
        return f"{uniform(low, high):.2f}"  # noqa: S311

    resolved = text
    passes = 0
    while passes < 100 and directive.findall(resolved):
        try:
            resolved = directive.sub(_draw, resolved)
        except Exception as e:
            raise ValueError(f"Invalid rule: '{text}'.") from e
        passes += 1

    # Hitting the pass limit means the text never settled
    if passes == 100:
        raise ValueError(f"Invalid rule: '{text}'.")

    return resolved
275
+
276
+
277
def evaluate_sai_random_ir(mixdb: "MixtureDatabase", text: str) -> str:
    """Evaluate 'sai_rand' directive for ir

    Supported forms (the whole text must match):

    - ``ir sai_rand``: random index in ``[0, num_ir_files)``
    - ``ir sai_rand(lower, upper)``: random index in ``[lower, upper]`` (inclusive),
      where ``0 <= lower < upper < num_ir_files`` are plain integers
    - ``ir sai_rand(tag)``: random index chosen from the IR files carrying *tag*

    :param mixdb: Mixture database
    :param text: Text to evaluate
    :return: Resolved value of the form ``ir <index>``
    :raises ValueError: If the directive is malformed, out of range, or names an unknown tag
    """
    import re
    from random import choice
    from random import randint

    num_ir = mixdb.num_ir_files

    if re.match(r"^ir sai_rand$", text):
        if num_ir < 1:
            raise ValueError(f"Invalid rule: '{text}'. No impulse responses in database.")
        # randint is inclusive on both ends, so the upper bound is num_ir - 1
        return f"ir {randint(0, num_ir - 1)}"  # noqa: S311

    range_match = re.match(r"^ir sai_rand\(([-+]?\d+),\s*([-+]?\d+)\)$", text)
    if range_match:
        lower_text, upper_text = range_match.groups()
        lower = int(lower_text)
        upper = int(upper_text)
        # Reject out-of-range or empty ranges, and non-canonical spellings such as "+1" or "01"
        if not (0 <= lower < upper < num_ir) or str(lower) != lower_text or str(upper) != upper_text:
            raise ValueError(f"Invalid rule: '{text}'. Values must be integers between 0 and {num_ir - 1}.")
        return f"ir {randint(lower, upper)}"  # noqa: S311

    tag_match = re.match(r"^ir sai_rand\((\w+)\)$", text)
    if tag_match:
        tag = tag_match.group(1)
        if tag in mixdb.ir_tags:
            return f"ir {choice(mixdb.ir_file_ids_for_tag(tag))}"  # noqa: S311

        raise ValueError(f"Invalid rule: '{text}'. Tag, '{tag}', not found in database.")

    raise ValueError(f"Invalid rule: '{text}'.")
329
+
330
+
331
def effects_from_rules(mixdb: MixtureDatabase, rules: Effects) -> Effects:
    """Resolve the random directives in a copy of an effect rule.

    Every ``pre``/``post`` entry containing ``sai_rand`` is resolved. Entries starting
    with ``ir`` get a random impulse response index; all others get random float
    values.

    :param mixdb: Mixture database (used to resolve IR directives)
    :param rules: Effect rule; left unmodified
    :return: Deep copy of *rules* with its ``sai_rand`` directives resolved
    """
    from copy import deepcopy

    resolved = deepcopy(rules)
    for chain_name in ("pre", "post"):
        chain = getattr(resolved, chain_name)
        for position, entry in enumerate(chain):
            if "sai_rand" not in entry:
                continue
            chain[position] = (
                evaluate_sai_random_ir(mixdb, entry) if entry.startswith("ir") else evaluate_sai_random_float(entry)
            )
        setattr(resolved, chain_name, chain)

    return resolved
346
+
347
+
348
def conform_audio_to_length(audio: "AudioT", length: int, repeat: bool, start: int) -> "AudioT":
    """Conform audio to given length

    Returns exactly *length* samples starting at sample offset *start*.

    :param audio: Audio to conform
    :param length: Length of output
    :param repeat: If True, wrap around to the beginning of *audio* to fill the output;
        otherwise truncate, and zero-pad at the end when *audio* runs out
    :param start: Starting sample offset
    :return: Conformed audio
    """
    import numpy as np

    if repeat:
        return np.take(audio, range(start, start + length), mode="wrap")

    # Slice first so longer audio is truncated. Padding by `start + length - len(audio)`
    # would be negative (np.pad raises) whenever the audio extends past the end, and
    # would over-pad when `start` lies beyond the end of the audio.
    segment = audio[start : start + length]
    return np.pad(segment, (0, length - len(segment)))
364
+
365
+
366
def validate_rules(mixdb: MixtureDatabase, rules: dict[str, list[Effects]]) -> None:
    """Check that every effect rule resolves to an acceptable sox effect chain.

    For each rule, the random directives are resolved once and entries starting with
    ``ir`` are skipped. The remaining pre- and post-truth effects are passed together
    to ``validate_sox_effects``, whose own errors propagate. Length-changing effects
    (``speed``, ``tempo``) are rejected anywhere in the post-truth chain.

    :param mixdb: Mixture database
    :param rules: Effect rules keyed by source category
    :raises ValueError: If a post-truth effect mentions ``speed`` or ``tempo``
    """
    from .sox_effects import validate_sox_effects

    for rule in (r for rule_list in rules.values() for r in rule_list):
        resolved = effects_from_rules(mixdb, rule)

        sox_chain = [entry for entry in resolved.pre if not entry.startswith("ir")]

        for entry in resolved.post:
            forbidden = next((name for name in ("speed", "tempo") if name in entry), None)
            if forbidden is not None:
                raise ValueError(f"'{forbidden}' effect is not allowed in post-truth effect chain.")

            if not entry.startswith("ir"):
                sox_chain.append(entry)

        validate_sox_effects(sox_chain)
@@ -1,11 +1,8 @@
1
- from sonusai.mixture.datatypes import AudioT
2
- from sonusai.mixture.datatypes import Feature
1
+ from ..datatypes import AudioT
2
+ from ..datatypes import Feature
3
3
 
4
4
 
5
- def get_feature_from_audio(
6
- audio: AudioT,
7
- feature_mode: str,
8
- ) -> Feature:
5
+ def get_feature_from_audio(audio: AudioT, feature_mode: str) -> Feature:
9
6
  """Apply forward transform and generate feature data from audio data
10
7
 
11
8
  :param audio: Time domain audio data [samples]
@@ -14,7 +11,7 @@ def get_feature_from_audio(
14
11
  """
15
12
  from pyaaware import FeatureGenerator
16
13
 
17
- from .datatypes import TransformConfig
14
+ from ..datatypes import TransformConfig
18
15
  from .helpers import forward_transform
19
16
 
20
17
  fg = FeatureGenerator(feature_mode=feature_mode)
@@ -43,10 +40,9 @@ def get_audio_from_feature(feature: Feature, feature_mode: str) -> AudioT:
43
40
  import numpy as np
44
41
  from pyaaware import FeatureGenerator
45
42
 
46
- from sonusai.utils.compress import power_uncompress
47
- from sonusai.utils.stacked_complex import unstack_complex
48
-
49
- from .datatypes import TransformConfig
43
+ from ..datatypes import TransformConfig
44
+ from ..utils.compress import power_uncompress
45
+ from ..utils.stacked_complex import unstack_complex
50
46
  from .helpers import inverse_transform
51
47
 
52
48
  if feature.ndim != 3: