sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,13 @@
1
- from typing import Any
2
- from typing import Optional
1
+ from pyaaware import ForwardTransform
2
+ from pyaaware import InverseTransform
3
3
 
4
- from praatio.utilities.constants import Interval
5
-
6
- from sonusai import ForwardTransform
7
- from sonusai import InverseTransform
8
- from sonusai.mixture import EnergyT
9
4
  from sonusai.mixture.datatypes import AudioF
10
- from sonusai.mixture.datatypes import AudioT
11
5
  from sonusai.mixture.datatypes import AudiosT
6
+ from sonusai.mixture.datatypes import AudioT
12
7
  from sonusai.mixture.datatypes import Augmentation
13
8
  from sonusai.mixture.datatypes import AugmentationRules
14
9
  from sonusai.mixture.datatypes import Augmentations
10
+ from sonusai.mixture.datatypes import EnergyT
15
11
  from sonusai.mixture.datatypes import Feature
16
12
  from sonusai.mixture.datatypes import FeatureGeneratorConfig
17
13
  from sonusai.mixture.datatypes import FeatureGeneratorInfo
@@ -25,37 +21,33 @@ from sonusai.mixture.datatypes import Target
25
21
  from sonusai.mixture.datatypes import TargetFiles
26
22
  from sonusai.mixture.datatypes import Targets
27
23
  from sonusai.mixture.datatypes import TransformConfig
28
- from sonusai.mixture.datatypes import Truth
24
+ from sonusai.mixture.datatypes import TruthDict
29
25
  from sonusai.mixture.db_datatypes import MixtureRecord
30
26
  from sonusai.mixture.db_datatypes import TargetRecord
31
27
  from sonusai.mixture.mixdb import MixtureDatabase
32
28
 
33
29
 
34
- def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = None) -> list[int]:
30
+ def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = "*") -> list[int]:
35
31
  """Resolve generalized IDs to a list of integers
36
32
 
37
33
  :param num_ids: Total number of indices
38
34
  :param ids: Generalized IDs
39
35
  :return: List of ID integers
40
36
  """
41
- from sonusai import SonusAIError
42
-
43
37
  all_ids = list(range(num_ids))
44
38
 
45
- if ids is None:
46
- return all_ids
47
-
48
39
  if isinstance(ids, str):
49
- if ids == '*':
40
+ if ids == "*":
50
41
  return all_ids
51
42
 
52
43
  try:
53
- result = eval(f'{all_ids}[{ids}]')
54
- if not isinstance(result, list):
55
- result = [result]
56
- return result
57
- except NameError:
58
- raise SonusAIError(f'Empty ids {ids}')
44
+ result = eval(f"{all_ids}[{ids}]") # noqa: S307
45
+ if isinstance(result, list):
46
+ return result
47
+ else:
48
+ return [result]
49
+ except NameError as e:
50
+ raise ValueError(f"Empty ids {ids}: {e}") from e
59
51
 
60
52
  if isinstance(ids, range):
61
53
  result = list(ids)
@@ -65,15 +57,17 @@ def generic_ids_to_list(num_ids: int, ids: GeneralizedIDs = None) -> list[int]:
65
57
  result = ids
66
58
 
67
59
  if not all(isinstance(x, int) and 0 <= x < num_ids for x in result):
68
- raise SonusAIError(f'Invalid entries in ids of {ids}')
60
+ raise ValueError(f"Invalid entries in ids of {ids}")
69
61
 
70
62
  if not result:
71
- raise SonusAIError(f'Empty ids {ids}')
63
+ raise ValueError(f"Empty ids {ids}")
72
64
 
73
65
  return result
74
66
 
75
67
 
76
- def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGeneratorInfo:
68
+ def get_feature_generator_info(
69
+ fg_config: FeatureGeneratorConfig,
70
+ ) -> FeatureGeneratorInfo:
77
71
  from dataclasses import asdict
78
72
 
79
73
  from pyaaware import FeatureGenerator
@@ -88,49 +82,36 @@ def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGene
88
82
  stride=fg.stride,
89
83
  step=fg.step,
90
84
  feature_parameters=fg.feature_parameters,
91
- ft_config=TransformConfig(N=fg.ftransform_N,
92
- R=fg.ftransform_R,
93
- bin_start=fg.bin_start,
94
- bin_end=fg.bin_end,
95
- ttype=fg.ftransform_ttype),
96
- eft_config=TransformConfig(N=fg.eftransform_N,
97
- R=fg.eftransform_R,
98
- bin_start=fg.bin_start,
99
- bin_end=fg.bin_end,
100
- ttype=fg.eftransform_ttype),
101
- it_config=TransformConfig(N=fg.itransform_N,
102
- R=fg.itransform_R,
103
- bin_start=fg.bin_start,
104
- bin_end=fg.bin_end,
105
- ttype=fg.itransform_ttype)
85
+ ft_config=TransformConfig(
86
+ length=fg.ftransform_length,
87
+ overlap=fg.ftransform_overlap,
88
+ bin_start=fg.bin_start,
89
+ bin_end=fg.bin_end,
90
+ ttype=fg.ftransform_ttype,
91
+ ),
92
+ eft_config=TransformConfig(
93
+ length=fg.eftransform_length,
94
+ overlap=fg.eftransform_overlap,
95
+ bin_start=fg.bin_start,
96
+ bin_end=fg.bin_end,
97
+ ttype=fg.eftransform_ttype,
98
+ ),
99
+ it_config=TransformConfig(
100
+ length=fg.itransform_length,
101
+ overlap=fg.itransform_overlap,
102
+ bin_start=fg.bin_start,
103
+ bin_end=fg.bin_end,
104
+ ttype=fg.itransform_ttype,
105
+ ),
106
106
  )
107
107
 
108
108
 
109
- def write_mixture_data(mixdb: MixtureDatabase,
110
- mixture: Mixture,
111
- items: list[tuple[str, Any]] | tuple[str, Any]) -> None:
112
- """Write mixture data to a mixture HDF5 file
113
-
114
- :param mixdb: Mixture database
115
- :param mixture: Mixture record
116
- :param items: Tuple(s) of (name, data)
117
- """
118
- import h5py
119
-
120
- if not isinstance(items, list):
121
- items = [items]
122
-
123
- name = mixdb.location_filename(mixture.name)
124
- with h5py.File(name=name, mode='a') as f:
125
- for item in items:
126
- if item[0] in f:
127
- del f[item[0]]
128
- f.create_dataset(name=item[0], data=item[1])
109
+ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
110
+ """Get a list of all speech metadata for the given mixture"""
111
+ from praatio.utilities.constants import Interval
129
112
 
113
+ from .datatypes import SpeechMetadata
130
114
 
131
- def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
132
- """Get a list of all speech metadata for the given mixture
133
- """
134
115
  results: list[dict[str, SpeechMetadata]] = []
135
116
  for target in mixture.targets:
136
117
  data: dict[str, SpeechMetadata] = {}
@@ -144,9 +125,13 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
144
125
  entries = []
145
126
  for entry in item:
146
127
  if target.augmentation.tempo is not None:
147
- entries.append(Interval(entry.start / target.augmentation.tempo,
148
- entry.end / target.augmentation.tempo,
149
- entry.label))
128
+ entries.append(
129
+ Interval(
130
+ entry.start / target.augmentation.tempo,
131
+ entry.end / target.augmentation.tempo,
132
+ entry.label,
133
+ )
134
+ )
150
135
  else:
151
136
  entries.append(entry)
152
137
  data[tier] = entries
@@ -164,41 +149,32 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
164
149
  :param mixture: Mixture record
165
150
  :return: String of metadata
166
151
  """
167
- metadata = ''
152
+ metadata = ""
168
153
  speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
169
154
  for mi, target in enumerate(mixture.targets):
170
155
  target_file = mixdb.target_file(target.file_id)
171
156
  target_augmentation = target.augmentation
172
- metadata += f'target {mi} name: {target_file.name}\n'
173
- metadata += f'target {mi} augmentation: {target.augmentation.to_dict()}\n'
174
- if target_augmentation.ir is None:
175
- ir_name = None
176
- else:
177
- ir_name = mixdb.impulse_response_file(target_augmentation.ir)
178
- metadata += f'target {mi} ir: {ir_name}\n'
179
- metadata += f'target {mi} target_gain: {target.gain}\n'
180
- truth_settings = target_file.truth_settings
181
- for tsi in range(len(truth_settings)):
182
- metadata += f'target {mi} truth index {tsi}: {truth_settings[tsi].index}\n'
183
- metadata += f'target {mi} truth function {tsi}: {truth_settings[tsi].function}\n'
184
- metadata += f'target {mi} truth config {tsi}: {truth_settings[tsi].config}\n'
185
- for key in speech_metadata[mi].keys():
186
- metadata += f'target {mi} speech {key}: {speech_metadata[mi][key]}\n'
157
+ metadata += f"target {mi} name: {target_file.name}\n"
158
+ metadata += f"target {mi} augmentation: {target.augmentation.to_dict()}\n"
159
+ metadata += f"target {mi} ir: {mixdb.impulse_response_file(target_augmentation.ir)}\n"
160
+ metadata += f"target {mi} target_gain: {target.gain}\n"
161
+ metadata += f"target {mi} class indices: {target_file.class_indices}\n"
162
+ for key in target_file.truth_configs:
163
+ metadata += f"target {mi} truth '{key}' function: {target_file.truth_configs[key].function}\n"
164
+ metadata += f"target {mi} truth '{key}' config: {target_file.truth_configs[key].config}\n"
165
+ for key in speech_metadata[mi]:
166
+ metadata += f"target {mi} speech {key}: {speech_metadata[mi][key]}\n"
187
167
  noise = mixdb.noise_file(mixture.noise.file_id)
188
168
  noise_augmentation = mixture.noise.augmentation
189
- metadata += f'noise name: {noise.name}\n'
190
- metadata += f'noise augmentation: {noise_augmentation.to_dict()}\n'
191
- if noise_augmentation.ir is None:
192
- ir_name = None
193
- else:
194
- ir_name = mixdb.impulse_response_file(noise_augmentation.ir)
195
- metadata += f'noise ir: {ir_name}\n'
196
- metadata += f'noise offset: {mixture.noise.offset}\n'
197
- metadata += f'snr: {mixture.snr}\n'
198
- metadata += f'random_snr: {mixture.snr.is_random}\n'
199
- metadata += f'samples: {mixture.samples}\n'
200
- metadata += f'target_snr_gain: {float(mixture.target_snr_gain)}\n'
201
- metadata += f'noise_snr_gain: {float(mixture.noise_snr_gain)}\n'
169
+ metadata += f"noise name: {noise.name}\n"
170
+ metadata += f"noise augmentation: {noise_augmentation.to_dict()}\n"
171
+ metadata += f"noise ir: {mixdb.impulse_response_file(noise_augmentation.ir)}\n"
172
+ metadata += f"noise offset: {mixture.noise.offset}\n"
173
+ metadata += f"snr: {mixture.snr}\n"
174
+ metadata += f"random_snr: {mixture.snr.is_random}\n"
175
+ metadata += f"samples: {mixture.samples}\n"
176
+ metadata += f"target_snr_gain: {float(mixture.target_snr_gain)}\n"
177
+ metadata += f"noise_snr_gain: {float(mixture.noise_snr_gain)}\n"
202
178
 
203
179
  return metadata
204
180
 
@@ -209,47 +185,54 @@ def write_mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> None:
209
185
  :param mixdb: Mixture database
210
186
  :param mixture: Mixture record
211
187
  """
212
- from os.path import splitext
188
+ from os.path import join
213
189
 
214
- name = mixdb.location_filename(splitext(mixture.name)[0] + '.txt')
215
- with open(file=name, mode='w') as f:
190
+ name = join(mixdb.location, "mixture", mixture.name, "metadata.txt")
191
+ with open(file=name, mode="w") as f:
216
192
  f.write(mixture_metadata(mixdb, mixture))
217
193
 
218
194
 
219
- def from_mixture(mixture: Mixture) -> tuple[str, int, str, int, float, bool, float, int, int, int, float]:
220
- return (mixture.name,
221
- mixture.noise.file_id,
222
- mixture.noise.augmentation.to_json(),
223
- mixture.noise.offset,
224
- mixture.noise_snr_gain,
225
- mixture.snr.is_random,
226
- mixture.snr,
227
- mixture.samples,
228
- mixture.spectral_mask_id,
229
- mixture.spectral_mask_seed,
230
- mixture.target_snr_gain)
195
+ def from_mixture(
196
+ mixture: Mixture,
197
+ ) -> tuple[str, int, str, int, float, bool, float, int, int, int, float]:
198
+ return (
199
+ mixture.name,
200
+ mixture.noise.file_id,
201
+ mixture.noise.augmentation.to_json(),
202
+ mixture.noise.offset,
203
+ mixture.noise_snr_gain,
204
+ mixture.snr.is_random,
205
+ mixture.snr,
206
+ mixture.samples,
207
+ mixture.spectral_mask_id,
208
+ mixture.spectral_mask_seed,
209
+ mixture.target_snr_gain,
210
+ )
231
211
 
232
212
 
233
213
  def to_mixture(entry: MixtureRecord, targets: Targets) -> Mixture:
234
214
  import json
235
215
 
236
216
  from sonusai.utils import dataclass_from_dict
237
- from .datatypes import Augmentation
238
- from .datatypes import Mixture
217
+
239
218
  from .datatypes import Noise
240
219
  from .datatypes import UniversalSNR
241
220
 
242
- return Mixture(targets=targets,
243
- name=entry.name,
244
- noise=Noise(file_id=entry.noise_file_id,
245
- augmentation=dataclass_from_dict(Augmentation, json.loads(entry.noise_augmentation)),
246
- offset=entry.noise_offset),
247
- noise_snr_gain=entry.noise_snr_gain,
248
- snr=UniversalSNR(is_random=entry.random_snr, value=entry.snr),
249
- samples=entry.samples,
250
- spectral_mask_id=entry.spectral_mask_id,
251
- spectral_mask_seed=entry.spectral_mask_seed,
252
- target_snr_gain=entry.target_snr_gain)
221
+ return Mixture(
222
+ targets=targets,
223
+ name=entry.name,
224
+ noise=Noise(
225
+ file_id=entry.noise_file_id,
226
+ augmentation=dataclass_from_dict(Augmentation, json.loads(entry.noise_augmentation)),
227
+ offset=entry.noise_offset,
228
+ ),
229
+ noise_snr_gain=entry.noise_snr_gain,
230
+ snr=UniversalSNR(is_random=entry.random_snr, value=entry.snr),
231
+ samples=entry.samples,
232
+ spectral_mask_id=entry.spectral_mask_id,
233
+ spectral_mask_seed=entry.spectral_mask_seed,
234
+ target_snr_gain=entry.target_snr_gain,
235
+ )
253
236
 
254
237
 
255
238
  def from_target(target: Target) -> tuple[int, str, float]:
@@ -260,105 +243,67 @@ def to_target(entry: TargetRecord) -> Target:
260
243
  import json
261
244
 
262
245
  from sonusai.utils import dataclass_from_dict
246
+
263
247
  from .datatypes import Augmentation
264
248
  from .datatypes import Target
265
249
 
266
- return Target(file_id=entry.file_id,
267
- augmentation=dataclass_from_dict(Augmentation, json.loads(entry.augmentation)),
268
- gain=entry.gain)
269
-
270
-
271
- def read_mixture_data(name: str, items: list[str] | str) -> Any:
272
- """Read mixture data from a mixture HDF5 file
273
-
274
- :param name: Mixture file name
275
- :param items: String(s) of dataset(s) to retrieve
276
- :return: Data (or tuple of data)
277
- """
278
- from os.path import exists
279
-
280
- import h5py
281
- import numpy as np
282
-
283
- from sonusai import SonusAIError
284
-
285
- def _get_dataset(file: h5py.File, d_name: str) -> Any:
286
- if d_name in file:
287
- data = np.array(file[d_name])
288
- if data.size == 1:
289
- item = data.item()
290
- if isinstance(item, bytes):
291
- return item.decode('utf-8')
292
- return item
293
- return data
294
- return None
295
-
296
- if not isinstance(items, list):
297
- items = [items]
298
-
299
- if exists(name):
300
- try:
301
- with h5py.File(name, 'r') as f:
302
- result = ([_get_dataset(f, item) for item in items])
303
- except Exception as e:
304
- raise SonusAIError(f'Error reading {name}: {e}')
305
- else:
306
- result = ([None for _ in items])
307
-
308
- if len(items) == 1:
309
- result = result[0]
310
-
311
- return result
250
+ return Target(
251
+ file_id=entry.file_id,
252
+ augmentation=dataclass_from_dict(Augmentation, json.loads(entry.augmentation)),
253
+ gain=entry.gain,
254
+ )
312
255
 
313
256
 
314
- def get_truth_t(mixdb: MixtureDatabase,
315
- mixture: Mixture,
316
- targets_audio: AudiosT,
317
- noise_audio: AudioT,
318
- mixture_audio: AudioT) -> Truth:
319
- """Get the truth_t data for the given mixture record
257
+ def get_truth(
258
+ mixdb: MixtureDatabase,
259
+ mixture: Mixture,
260
+ targets_audio: AudiosT,
261
+ noise_audio: AudioT,
262
+ mixture_audio: AudioT,
263
+ ) -> TruthDict:
264
+ """Get the truth data for the given mixture record
320
265
 
321
266
  :param mixdb: Mixture database
322
267
  :param mixture: Mixture record
323
268
  :param targets_audio: List of augmented target audio data (one per target in the mixup) for the given mixture ID
324
269
  :param noise_audio: Augmented noise audio data for the given mixture ID
325
270
  :param mixture_audio: Mixture audio data for the given mixture ID
326
- :return: truth_t data
271
+ :return: truth data
327
272
  """
328
- import numpy as np
329
-
330
- from sonusai import SonusAIError
331
- from .datatypes import TruthFunctionConfig
273
+ from .datatypes import TruthDict
332
274
  from .truth import truth_function
333
275
 
334
276
  if not all(len(target) == mixture.samples for target in targets_audio):
335
- raise SonusAIError('Lengths of targets do not match length of mixture')
277
+ raise ValueError("Lengths of targets do not match length of mixture")
336
278
 
337
279
  if len(noise_audio) != mixture.samples:
338
- raise SonusAIError('Length of noise does not match length of mixture')
280
+ raise ValueError("Length of noise does not match length of mixture")
339
281
 
340
282
  # TODO: Need to understand how to do this correctly for mixup and target_mixture_f truth
341
- truth_t = np.zeros((mixture.samples, mixdb.num_classes), dtype=np.float32)
283
+ if len(targets_audio) != 1:
284
+ raise NotImplementedError("mixup is not implemented")
285
+
286
+ truth: TruthDict = {}
342
287
  for idx in range(len(targets_audio)):
343
- for truth_setting in mixdb.target_file(mixture.targets[idx].file_id).truth_settings:
344
- config = TruthFunctionConfig(
288
+ target_file = mixdb.target_file(mixture.targets[idx].file_id)
289
+ for key, value in target_file.truth_configs.items():
290
+ truth[key] = truth_function(
291
+ target_audio=targets_audio[idx],
292
+ noise_audio=noise_audio,
293
+ mixture_audio=mixture_audio,
294
+ config=value,
345
295
  feature=mixdb.feature,
346
- index=truth_setting.index,
347
- function=truth_setting.function,
348
- config=truth_setting.config,
349
296
  num_classes=mixdb.num_classes,
350
- mutex=mixdb.truth_mutex,
351
- target_gain=mixture.targets[idx].gain * mixture.target_snr_gain
297
+ class_indices=target_file.class_indices,
298
+ target_gain=mixture.targets[idx].gain * mixture.target_snr_gain,
352
299
  )
353
- truth_t += truth_function(target_audio=targets_audio[idx],
354
- noise_audio=noise_audio,
355
- mixture_audio=mixture_audio,
356
- config=config)
357
300
 
358
- return truth_t
301
+ return truth
359
302
 
360
303
 
361
- def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, truth_t: Truth) -> tuple[Feature, Truth]:
304
+ def get_ft(
305
+ mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, truth_t: TruthDict
306
+ ) -> tuple[Feature, TruthDict]:
362
307
  """Get the feature and truth_f data for the given mixture record
363
308
 
364
309
  :param mixdb: Mixture database
@@ -367,37 +312,19 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
367
312
  :param truth_t: truth_t for the given mixid
368
313
  :return: Tuple of (feature, truth_f) data
369
314
  """
370
- from dataclasses import asdict
371
315
 
372
- import numpy as np
373
316
  from pyaaware import FeatureGenerator
374
317
 
375
- from .truth import truth_reduction
318
+ from .truth import truth_stride_reduction
376
319
 
377
320
  mixture_f = get_mixture_f(mixdb=mixdb, mixture=mixture, mixture_audio=mixture_audio)
378
321
 
379
- transform_frames = frames_from_samples(mixture.samples, mixdb.ft_config.R)
380
- feature_frames = frames_from_samples(mixture.samples, mixdb.feature_step_samples)
322
+ fg = FeatureGenerator(mixdb.fg_config.feature_mode, mixdb.fg_config.truth_parameters)
323
+ feature, truth_f = fg.execute_all(mixture_f, truth_t)
324
+ for name in truth_f:
325
+ truth_f[name] = truth_stride_reduction(truth_f[name], mixdb.truth_configs[name].stride_reduction)
381
326
 
382
- feature = np.empty((feature_frames, mixdb.fg_stride, mixdb.feature_parameters), dtype=np.float32)
383
- truth_f = np.empty((feature_frames, mixdb.num_classes), dtype=np.complex64)
384
-
385
- fg = FeatureGenerator(**asdict(mixdb.fg_config))
386
- feature_frame = 0
387
- for transform_frame in range(transform_frames):
388
- indices = slice(transform_frame * mixdb.ft_config.R, (transform_frame + 1) * mixdb.ft_config.R)
389
- fg.execute(mixture_f[transform_frame],
390
- truth_reduction(truth_t[indices], mixdb.truth_reduction_function))
391
-
392
- if fg.eof():
393
- feature[feature_frame] = fg.feature()
394
- truth_f[feature_frame] = fg.truth()
395
- feature_frame += 1
396
-
397
- if np.isreal(truth_f).all():
398
- return feature, truth_f.real
399
-
400
- return feature, truth_f # type: ignore
327
+ return feature, truth_f
401
328
 
402
329
 
403
330
  def get_segsnr(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise: AudioT) -> Segsnr:
@@ -410,7 +337,7 @@ def get_segsnr(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, n
410
337
  :return: segsnr data
411
338
  """
412
339
  segsnr_t = get_segsnr_t(mixdb=mixdb, mixture=mixture, target_audio=target_audio, noise_audio=noise)
413
- return segsnr_t[0::mixdb.ft_config.R]
340
+ return segsnr_t[0 :: mixdb.ft_config.overlap]
414
341
 
415
342
 
416
343
  def get_segsnr_t(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT) -> Segsnr:
@@ -424,28 +351,29 @@ def get_segsnr_t(mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT,
424
351
  """
425
352
  import numpy as np
426
353
  import torch
427
- from sonusai import ForwardTransform
428
-
429
- from sonusai import SonusAIError
430
-
431
- fft = ForwardTransform(N=mixdb.ft_config.N,
432
- R=mixdb.ft_config.R,
433
- bin_start=mixdb.ft_config.bin_start,
434
- bin_end=mixdb.ft_config.bin_end,
435
- ttype=mixdb.ft_config.ttype)
354
+ from pyaaware import ForwardTransform
355
+
356
+ fft = ForwardTransform(
357
+ length=mixdb.ft_config.length,
358
+ overlap=mixdb.ft_config.overlap,
359
+ bin_start=mixdb.ft_config.bin_start,
360
+ bin_end=mixdb.ft_config.bin_end,
361
+ ttype=mixdb.ft_config.ttype,
362
+ )
436
363
 
437
364
  segsnr_t = np.empty(mixture.samples, dtype=np.float32)
438
365
 
439
366
  target_energy = fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
440
367
  noise_energy = fft.execute_all(torch.from_numpy(noise_audio))[1].numpy()
441
368
 
442
- offsets = range(0, mixture.samples, mixdb.ft_config.R)
369
+ offsets = range(0, mixture.samples, mixdb.ft_config.overlap)
443
370
  if len(target_energy) != len(offsets):
444
- raise SonusAIError(f'Number of frames in energy, {len(target_energy)},'
445
- f' is not number of frames in mixture, {len(offsets)}')
371
+ raise ValueError(
372
+ f"Number of frames in energy, {len(target_energy)}," f" is not number of frames in mixture, {len(offsets)}"
373
+ )
446
374
 
447
375
  for idx, offset in enumerate(offsets):
448
- indices = slice(offset, offset + mixdb.ft_config.R)
376
+ indices = slice(offset, offset + mixdb.ft_config.overlap)
449
377
 
450
378
  if noise_energy[idx] == 0:
451
379
  snr = np.float32(np.inf)
@@ -475,8 +403,9 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: AudiosT)
475
403
  for idx, target in enumerate(targets_audio):
476
404
  ir_idx = mixture.targets[idx].augmentation.ir
477
405
  if ir_idx is not None:
478
- targets_ir.append(apply_impulse_response(audio=target,
479
- ir=read_ir(mixdb.impulse_response_file(int(ir_idx)))))
406
+ targets_ir.append(
407
+ apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
408
+ )
480
409
  else:
481
410
  targets_ir.append(target)
482
411
 
@@ -497,9 +426,11 @@ def get_mixture_f(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: Audio
497
426
  mixture_f = forward_transform(mixture_audio, mixdb.ft_config)
498
427
 
499
428
  if mixture.spectral_mask_id is not None:
500
- mixture_f = apply_spectral_mask(audio_f=mixture_f,
501
- spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
502
- seed=mixture.spectral_mask_seed)
429
+ mixture_f = apply_spectral_mask(
430
+ audio_f=mixture_f,
431
+ spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
432
+ seed=mixture.spectral_mask_seed,
433
+ )
503
434
 
504
435
  return mixture_f
505
436
 
@@ -527,14 +458,18 @@ def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
527
458
  :param config: Transform configuration
528
459
  :return: Frequency domain data [frames, bins]
529
460
  """
530
- from sonusai import ForwardTransform
531
-
532
- audio_f, _ = get_transform_from_audio(audio=audio,
533
- transform=ForwardTransform(N=config.N,
534
- R=config.R,
535
- bin_start=config.bin_start,
536
- bin_end=config.bin_end,
537
- ttype=config.ttype))
461
+ from pyaaware import ForwardTransform
462
+
463
+ audio_f, _ = get_transform_from_audio(
464
+ audio=audio,
465
+ transform=ForwardTransform(
466
+ length=config.length,
467
+ overlap=config.overlap,
468
+ bin_start=config.bin_start,
469
+ bin_end=config.bin_end,
470
+ ttype=config.ttype,
471
+ ),
472
+ )
538
473
  return audio_f
539
474
 
540
475
 
@@ -545,6 +480,7 @@ def get_audio_from_transform(data: AudioF, transform: InverseTransform) -> tuple
545
480
  :param transform: InverseTransform object
546
481
  :return: Time domain data [samples], Energy [frames]
547
482
  """
483
+
548
484
  import torch
549
485
 
550
486
  t, e = transform.execute_all(torch.from_numpy(data))
@@ -562,40 +498,44 @@ def inverse_transform(transform: AudioF, config: TransformConfig) -> AudioT:
562
498
  :return: Time domain data [samples]
563
499
  """
564
500
  import numpy as np
565
- from sonusai import InverseTransform
566
-
567
- audio, _ = get_audio_from_transform(data=transform,
568
- transform=InverseTransform(N=config.N,
569
- R=config.R,
570
- bin_start=config.bin_start,
571
- bin_end=config.bin_end,
572
- ttype=config.ttype,
573
- gain=np.float32(1)))
501
+ from pyaaware import InverseTransform
502
+
503
+ audio, _ = get_audio_from_transform(
504
+ data=transform,
505
+ transform=InverseTransform(
506
+ length=config.length,
507
+ overlap=config.overlap,
508
+ bin_start=config.bin_start,
509
+ bin_end=config.bin_end,
510
+ ttype=config.ttype,
511
+ gain=np.float32(1),
512
+ ),
513
+ )
574
514
  return audio
575
515
 
576
516
 
577
517
  def check_audio_files_exist(mixdb: MixtureDatabase) -> None:
578
- """Walk through all the noise and target audio files in a mixture database ensuring that they exist
579
- """
518
+ """Walk through all the noise and target audio files in a mixture database ensuring that they exist"""
580
519
  from os.path import exists
581
520
 
582
- from sonusai import SonusAIError
583
521
  from .tokenized_shell_vars import tokenized_expand
584
522
 
585
523
  for noise in mixdb.noise_files:
586
524
  file_name, _ = tokenized_expand(noise.name)
587
525
  if not exists(file_name):
588
- raise SonusAIError(f'Could not find {file_name}')
526
+ raise OSError(f"Could not find {file_name}")
589
527
 
590
528
  for target in mixdb.target_files:
591
529
  file_name, _ = tokenized_expand(target.name)
592
530
  if not exists(file_name):
593
- raise SonusAIError(f'Could not find {file_name}')
531
+ raise OSError(f"Could not find {file_name}")
594
532
 
595
533
 
596
- def augmented_target_samples(target_files: TargetFiles,
597
- target_augmentations: AugmentationRules,
598
- feature_step_samples: int) -> int:
534
+ def augmented_target_samples(
535
+ target_files: TargetFiles,
536
+ target_augmentations: AugmentationRules,
537
+ feature_step_samples: int,
538
+ ) -> int:
599
539
  from itertools import product
600
540
 
601
541
  from .augmentation import estimate_augmented_length_from_length
@@ -603,10 +543,16 @@ def augmented_target_samples(target_files: TargetFiles,
603
543
  target_ids = list(range(len(target_files)))
604
544
  target_augmentation_ids = list(range(len(target_augmentations)))
605
545
  it = list(product(*[target_ids, target_augmentation_ids]))
606
- return sum([estimate_augmented_length_from_length(
607
- length=target_files[fi].samples,
608
- tempo=float(target_augmentations[ai].tempo),
609
- frame_length=feature_step_samples) for fi, ai, in it])
546
+ return sum(
547
+ [
548
+ estimate_augmented_length_from_length(
549
+ length=target_files[fi].samples,
550
+ tempo=target_augmentations[ai].tempo,
551
+ frame_length=feature_step_samples,
552
+ )
553
+ for fi, ai in it
554
+ ]
555
+ )
610
556
 
611
557
 
612
558
  def augmented_noise_samples(noise_files: NoiseFiles, noise_augmentations: Augmentations) -> int:
@@ -621,18 +567,17 @@ def augmented_noise_samples(noise_files: NoiseFiles, noise_augmentations: Augmen
621
567
  def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentation) -> int:
622
568
  from .augmentation import estimate_augmented_length_from_length
623
569
 
624
- return estimate_augmented_length_from_length(length=noise_file.samples,
625
- tempo=noise_augmentation.tempo)
570
+ return estimate_augmented_length_from_length(length=noise_file.samples, tempo=noise_augmentation.tempo)
626
571
 
627
572
 
628
- def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> Optional[SpeechMetadata]:
573
+ def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> SpeechMetadata | None:
629
574
  from pathlib import Path
630
575
 
631
576
  from praatio import textgrid
632
577
 
633
578
  from .tokenized_shell_vars import tokenized_expand
634
579
 
635
- textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix('.TextGrid')
580
+ textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix(".TextGrid")
636
581
  if not textgrid_file.exists():
637
582
  return None
638
583