sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,9 @@
1
- from typing import Optional
2
-
3
1
  from sonusai.mixture.datatypes import AudioT
4
2
  from sonusai.mixture.datatypes import Augmentation
5
3
  from sonusai.mixture.datatypes import AugmentationRule
6
4
  from sonusai.mixture.datatypes import AugmentationRules
7
5
  from sonusai.mixture.datatypes import ImpulseResponseData
6
+ from sonusai.mixture.datatypes import OptionalNumberStr
8
7
 
9
8
 
10
9
  def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> AugmentationRules:
@@ -15,6 +14,7 @@ def get_augmentation_rules(rules: list[dict] | dict, num_ir: int = 0) -> Augment
15
14
  :return: List of augmentation rules
16
15
  """
17
16
  from sonusai.utils import dataclass_from_dict
17
+
18
18
  from .datatypes import AugmentationRule
19
19
 
20
20
  processed_rules: list[dict] = []
@@ -37,8 +37,8 @@ def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
37
37
  """
38
38
  from copy import deepcopy
39
39
 
40
- from sonusai import SonusAIError
41
40
  from sonusai.utils import convert_string_to_number
41
+
42
42
  from .constants import VALID_AUGMENTATIONS
43
43
  from .eq_rule_is_valid import eq_rule_is_valid
44
44
 
@@ -47,46 +47,44 @@ def _expand_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
47
47
  del rule[key]
48
48
 
49
49
  # replace old 'eq' rule with new 'eq1' rule to allow both for backward compatibility
50
- rule = {'eq1' if key == 'eq' else key: value for key, value in rule.items()}
50
+ rule = {"eq1" if key == "eq" else key: value for key, value in rule.items()}
51
51
 
52
52
  for key in rule:
53
53
  if key not in VALID_AUGMENTATIONS:
54
- nice_list = '\n'.join([f' {item}' for item in VALID_AUGMENTATIONS])
55
- raise SonusAIError(f'Invalid augmentation: {key}.\nValid augmentations are:\n{nice_list}')
54
+ nice_list = "\n".join([f" {item}" for item in VALID_AUGMENTATIONS])
55
+ raise ValueError(f"Invalid augmentation: {key}.\nValid augmentations are:\n{nice_list}")
56
56
 
57
- if key in ['eq1', 'eq2', 'eq3']:
57
+ if key in ["eq1", "eq2", "eq3"]:
58
58
  if not eq_rule_is_valid(rule[key]):
59
- raise SonusAIError(f'Invalid augmentation value for {key}: {rule[key]}')
59
+ raise ValueError(f"Invalid augmentation value for {key}: {rule[key]}")
60
60
 
61
- if all(isinstance(el, list) or (isinstance(el, str) and el == 'none') for el in rule[key]):
61
+ if all(isinstance(el, list) or (isinstance(el, str) and el == "none") for el in rule[key]):
62
62
  # Expand multiple rules
63
63
  for value in rule[key]:
64
64
  expanded_rule = deepcopy(rule)
65
- if isinstance(value, str) and value == 'none':
65
+ if isinstance(value, str) and value == "none":
66
66
  expanded_rule[key] = None
67
67
  else:
68
68
  expanded_rule[key] = deepcopy(value)
69
69
  _expand_rules(expanded_rules, expanded_rule)
70
70
  return expanded_rules
71
71
 
72
- elif key in ['mixup']:
72
+ elif key in ["mixup"]:
73
73
  pass
74
74
 
75
75
  else:
76
76
  if isinstance(rule[key], list):
77
77
  for value in rule[key]:
78
78
  if isinstance(value, list):
79
- raise SonusAIError(f'Invalid augmentation value for {key}: {rule[key]}')
79
+ raise TypeError(f"Invalid augmentation value for {key}: {rule[key]}")
80
80
  expanded_rule = deepcopy(rule)
81
81
  expanded_rule[key] = deepcopy(value)
82
82
  _expand_rules(expanded_rules, expanded_rule)
83
83
  return expanded_rules
84
84
  else:
85
85
  rule[key] = convert_string_to_number(rule[key])
86
- if not (isinstance(rule[key], float | int) or
87
- rule[key].startswith('rand') or
88
- rule[key] == 'none'):
89
- raise SonusAIError(f'Invalid augmentation value for {key}: {rule[key]}')
86
+ if not (isinstance(rule[key], float | int) or rule[key].startswith("rand") or rule[key] == "none"):
87
+ raise ValueError(f"Invalid augmentation value for {key}: {rule[key]}")
90
88
 
91
89
  expanded_rules.append(rule)
92
90
  return expanded_rules
@@ -102,7 +100,7 @@ def _generate_none_rule(rule: dict) -> dict:
102
100
 
103
101
  out_rule = deepcopy(rule)
104
102
  for key in out_rule:
105
- if out_rule[key] == 'none':
103
+ if out_rule[key] == "none":
106
104
  out_rule[key] = None
107
105
 
108
106
  return out_rule
@@ -120,20 +118,20 @@ def _generate_random_rule(rule: dict, num_ir: int = 0) -> dict:
120
118
 
121
119
  out_rule = deepcopy(rule)
122
120
  for key in out_rule:
123
- if key == 'ir' and out_rule[key] == 'rand':
121
+ if key == "ir" and out_rule[key] == "rand":
124
122
  # IR is special case
125
123
  if num_ir == 0:
126
124
  out_rule[key] = None
127
125
  else:
128
- out_rule[key] = randint(0, num_ir - 1)
126
+ out_rule[key] = randint(0, num_ir - 1) # noqa: S311
129
127
  else:
130
128
  out_rule[key] = evaluate_random_rule(str(out_rule[key]))
131
129
 
132
130
  # convert EQ values from strings to numbers
133
- if key in ['eq1', 'eq2', 'eq3']:
131
+ if key in ["eq1", "eq2", "eq3"]:
134
132
  for n in range(3):
135
133
  if isinstance(out_rule[key][n], str):
136
- out_rule[key][n] = eval(out_rule[key][n])
134
+ out_rule[key][n] = eval(out_rule[key][n]) # noqa: S307
137
135
 
138
136
  return out_rule
139
137
 
@@ -144,14 +142,10 @@ def _rule_has_rand(rule: dict) -> bool:
144
142
  :param rule: Rule
145
143
  :return: True if rule contains 'rand'
146
144
  """
147
- for key in rule:
148
- if 'rand' in str(rule[key]):
149
- return True
150
-
151
- return False
145
+ return any("rand" in str(rule[key]) for key in rule)
152
146
 
153
147
 
154
- def estimate_augmented_length_from_length(length: int, tempo: Optional[float] = None, frame_length: int = 1) -> int:
148
+ def estimate_augmented_length_from_length(length: int, tempo: OptionalNumberStr = None, frame_length: int = 1) -> int:
155
149
  """Estimate the length of audio after augmentation
156
150
 
157
151
  :param length: Number of samples in audio
@@ -162,7 +156,7 @@ def estimate_augmented_length_from_length(length: int, tempo: Optional[float] =
162
156
  import numpy as np
163
157
 
164
158
  if tempo is not None:
165
- length = int(np.round(length / tempo))
159
+ length = int(np.round(length / float(tempo)))
166
160
 
167
161
  length = _get_padded_length(length, frame_length)
168
162
 
@@ -175,7 +169,7 @@ def get_mixups(augmentations: AugmentationRules) -> list[int]:
175
169
  :param augmentations: List of augmentations
176
170
  :return: List of mixup values used
177
171
  """
178
- return sorted(list(set([augmentation.mixup for augmentation in augmentations])))
172
+ return sorted({augmentation.mixup for augmentation in augmentations})
179
173
 
180
174
 
181
175
  def get_augmentation_indices_for_mixup(augmentations: AugmentationRules, mixup: int) -> list[int]:
@@ -249,53 +243,52 @@ def evaluate_random_rule(rule: str) -> str | float:
249
243
  from .constants import RAND_PATTERN
250
244
 
251
245
  def rand_repl(m):
252
- return f'{uniform(float(m.group(1)), float(m.group(4))):.2f}'
246
+ return f"{uniform(float(m.group(1)), float(m.group(4))):.2f}" # noqa: S311
253
247
 
254
- return eval(re.sub(RAND_PATTERN, rand_repl, rule))
248
+ return eval(re.sub(RAND_PATTERN, rand_repl, rule)) # noqa: S307
255
249
 
256
250
 
257
251
  def _parse_ir(rule: dict, num_ir: int) -> dict:
258
- from sonusai import SonusAIError
259
252
  from .helpers import generic_ids_to_list
260
253
 
261
254
  def _resolve_str(rule_in: str) -> str | list[int]:
262
- if rule_in in ['rand', 'none']:
255
+ if rule_in in ["rand", "none"]:
263
256
  return rule_in
264
257
 
265
258
  rule_out = generic_ids_to_list(num_ir, rule_in)
266
259
  if not all(ro in range(num_ir) for ro in rule_out):
267
- raise SonusAIError(f'Invalid ir entry of {rule_in}')
260
+ raise ValueError(f"Invalid ir entry of {rule_in}")
268
261
  return rule_out
269
262
 
270
- if 'ir' not in rule:
263
+ if "ir" not in rule:
271
264
  return rule
272
265
 
273
- ir = rule['ir']
266
+ ir = rule["ir"]
274
267
 
275
268
  if ir is None:
276
269
  return rule
277
270
 
278
271
  if isinstance(ir, str):
279
- rule['ir'] = _resolve_str(ir)
272
+ rule["ir"] = _resolve_str(ir)
280
273
  return rule
281
274
 
282
275
  if isinstance(ir, list):
283
- rule['ir'] = []
276
+ rule["ir"] = []
284
277
  for item in ir:
285
278
  result = _resolve_str(item)
286
279
  if isinstance(result, str):
287
- rule['ir'].append(_resolve_str(item))
280
+ rule["ir"].append(_resolve_str(item))
288
281
  else:
289
- rule['ir'] += _resolve_str(item)
282
+ rule["ir"] += _resolve_str(item)
290
283
 
291
284
  return rule
292
285
 
293
286
  if isinstance(ir, int):
294
287
  if ir not in range(num_ir):
295
- raise SonusAIError(f'Invalid ir of {ir}')
288
+ raise ValueError(f"Invalid ir of {ir}")
296
289
  return rule
297
290
 
298
- raise SonusAIError(f'Invalid ir of {ir}')
291
+ raise ValueError(f"Invalid ir of {ir}")
299
292
 
300
293
 
301
294
  def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length: int = 1) -> AudioT:
@@ -325,10 +318,11 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
325
318
 
326
319
  def augmentation_from_rule(rule: AugmentationRule, num_ir: int) -> Augmentation:
327
320
  from sonusai.utils import dataclass_from_dict
321
+
328
322
  from .datatypes import Augmentation
329
323
 
330
324
  processed_rule = rule.to_dict()
331
- del processed_rule['mixup']
325
+ del processed_rule["mixup"]
332
326
  processed_rule = _generate_none_rule(processed_rule)
333
327
  if _rule_has_rand(processed_rule):
334
328
  processed_rule = _generate_random_rule(processed_rule, num_ir)
@@ -3,22 +3,13 @@ from sonusai.mixture.datatypes import GeneralizedIDs
3
3
  from sonusai.mixture.mixdb import MixtureDatabase
4
4
 
5
5
 
6
- def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs = None) -> ClassCount:
7
- """ Sums the class counts for given mixids
8
- """
9
- from sonusai import SonusAIError
10
-
6
+ def get_class_count_from_mixids(mixdb: MixtureDatabase, mixids: GeneralizedIDs | None = None) -> ClassCount:
7
+ """Sums the class counts for given mixids"""
11
8
  total_class_count = [0] * mixdb.num_classes
12
- mixids = mixdb.mixids_to_list(mixids)
13
- for mixid in mixids:
14
- class_count = mixdb.mixture_class_count(mixid)
9
+ m_ids = mixdb.mixids_to_list(mixids)
10
+ for m_id in m_ids:
11
+ class_count = mixdb.mixture_class_count(m_id)
15
12
  for cl in range(mixdb.num_classes):
16
13
  total_class_count[cl] += class_count[cl]
17
14
 
18
- if mixdb.truth_mutex:
19
- # Compute the class count for the 'other' class
20
- if total_class_count[-1] != 0:
21
- raise SonusAIError('Error: truth_mutex was set, but the class count for the last count was non-zero.')
22
- total_class_count[-1] = sum([mixdb.mixture(mixid).samples for mixid in mixids]) - sum(total_class_count)
23
-
24
15
  return total_class_count