sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (61)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/config.py CHANGED
@@ -1,9 +1,8 @@
1
1
  from sonusai.mixture.datatypes import ImpulseResponseFile
2
- from sonusai.mixture.datatypes import ImpulseResponseFiles
3
- from sonusai.mixture.datatypes import NoiseFiles
4
- from sonusai.mixture.datatypes import SpectralMasks
5
- from sonusai.mixture.datatypes import TargetFiles
6
- from sonusai.mixture.datatypes import TruthParameters
2
+ from sonusai.mixture.datatypes import NoiseFile
3
+ from sonusai.mixture.datatypes import SpectralMask
4
+ from sonusai.mixture.datatypes import TargetFile
5
+ from sonusai.mixture.datatypes import TruthParameter
7
6
 
8
7
 
9
8
  def raw_load_config(name: str) -> dict:
@@ -210,7 +209,7 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
210
209
  return new_config
211
210
 
212
211
 
213
- def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
212
+ def get_target_files(config: dict, show_progress: bool = False) -> list[TargetFile]:
214
213
  """Get the list of target files from a config
215
214
 
216
215
  :param config: Config dictionary
@@ -223,7 +222,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
223
222
  from sonusai.utils import par_track
224
223
  from sonusai.utils import track
225
224
 
226
- from .datatypes import TargetFiles
225
+ from .datatypes import TargetFile
227
226
 
228
227
  class_indices = config["class_indices"]
229
228
  if not isinstance(class_indices, list):
@@ -255,7 +254,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
255
254
  if any(class_index > num_classes for class_index in target_file["class_indices"]):
256
255
  raise ValueError(f"class index elements must not be greater than {num_classes}")
257
256
 
258
- return dataclass_from_dict(TargetFiles, target_files)
257
+ return dataclass_from_dict(list[TargetFile], target_files)
259
258
 
260
259
 
261
260
  def append_target_files(
@@ -294,6 +293,7 @@ def append_target_files(
294
293
  if tokens is None:
295
294
  tokens = {}
296
295
 
296
+ truth_configs_merged = deepcopy(truth_configs)
297
297
  if isinstance(entry, dict):
298
298
  if "name" in entry:
299
299
  in_name = entry["name"]
@@ -312,15 +312,11 @@ def append_target_files(
312
312
  raise AttributeError(
313
313
  f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
314
314
  )
315
- truth_configs_merged = {}
316
- for key in truth_configs_override:
317
- truth_configs_merged[key] = deepcopy(truth_configs[key])
318
- if truth_configs_override[key] is not None:
315
+ if key in truth_configs_override:
319
316
  truth_configs_merged[key] |= truth_configs_override[key]
320
317
  level_type = entry.get("level_type", level_type)
321
318
  else:
322
319
  in_name = entry
323
- truth_configs_merged = deepcopy(truth_configs)
324
320
 
325
321
  in_name, new_tokens = tokenized_expand(in_name)
326
322
  tokens.update(new_tokens)
@@ -416,7 +412,7 @@ def append_target_files(
416
412
  return target_files
417
413
 
418
414
 
419
- def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
415
+ def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
420
416
  """Get the list of noise files from a config
421
417
 
422
418
  :param config: Config dictionary
@@ -429,7 +425,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
429
425
  from sonusai.utils import par_track
430
426
  from sonusai.utils import track
431
427
 
432
- from .datatypes import NoiseFiles
428
+ from .datatypes import NoiseFile
433
429
 
434
430
  noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
435
431
 
@@ -437,7 +433,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
437
433
  noise_files = par_track(_get_num_samples, noise_files, progress=progress)
438
434
  progress.close()
439
435
 
440
- return dataclass_from_dict(NoiseFiles, noise_files)
436
+ return dataclass_from_dict(list[NoiseFile], noise_files)
441
437
 
442
438
 
443
439
  def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
@@ -522,26 +518,25 @@ def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[di
522
518
  return noise_files
523
519
 
524
520
 
525
- def get_impulse_response_files(config: dict) -> ImpulseResponseFiles:
521
+ def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
526
522
  """Get the list of impulse response files from a config
527
523
 
528
524
  :param config: Config dictionary
529
525
  :return: List of impulse response files
530
526
  """
531
- return [ImpulseResponseFile(entry["name"], entry["tags"]) for entry in config["impulse_responses"]]
532
- # from itertools import chain
533
- #
534
- # return list(
535
- # chain.from_iterable(
536
- # [
537
- # append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry["tags"]))
538
- # for entry in config["impulse_responses"]
539
- # ]
540
- # )
541
- # )
542
-
543
-
544
- def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
527
+ from itertools import chain
528
+
529
+ return list(
530
+ chain.from_iterable(
531
+ [
532
+ append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
533
+ for entry in config["impulse_responses"]
534
+ ]
535
+ )
536
+ )
537
+
538
+
539
+ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
545
540
  """Process impulse response files list and append as needed
546
541
 
547
542
  :param entry: Impulse response file entry to append to the list
@@ -569,7 +564,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
569
564
  if not names:
570
565
  raise OSError(f"Could not find {in_name}. Make sure path exists")
571
566
 
572
- impulse_response_files: list[str] = []
567
+ impulse_response_files: list[ImpulseResponseFile] = []
573
568
  for name in names:
574
569
  ext = splitext(name)[1].lower()
575
570
  dir_name = dirname(name)
@@ -607,14 +602,14 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
607
602
  raise OSError(f"Error processing {name}: {e}") from e
608
603
  else:
609
604
  validate_input_file(name)
610
- impulse_response_files.append(tokenized_replace(name, tokens))
605
+ impulse_response_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags))
611
606
  except Exception as e:
612
607
  raise OSError(f"Error processing {name}: {e}") from e
613
608
 
614
609
  return impulse_response_files
615
610
 
616
611
 
617
- def get_spectral_masks(config: dict) -> SpectralMasks:
612
+ def get_spectral_masks(config: dict) -> list[SpectralMask]:
618
613
  """Get the list of spectral masks from a config
619
614
 
620
615
  :param config: Config dictionary
@@ -623,12 +618,12 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
623
618
  from sonusai.utils import dataclass_from_dict
624
619
 
625
620
  try:
626
- return dataclass_from_dict(SpectralMasks, config["spectral_masks"])
621
+ return dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
627
622
  except Exception as e:
628
623
  raise ValueError(f"Error in spectral_masks: {e}") from e
629
624
 
630
625
 
631
- def get_truth_parameters(config: dict) -> TruthParameters:
626
+ def get_truth_parameters(config: dict) -> list[TruthParameter]:
632
627
  """Get the list of truth parameters from a config
633
628
 
634
629
  :param config: Config dictionary
@@ -637,26 +632,21 @@ def get_truth_parameters(config: dict) -> TruthParameters:
637
632
  from copy import deepcopy
638
633
 
639
634
  from sonusai.mixture import truth_functions
640
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
641
635
 
642
636
  from .constants import REQUIRED_TRUTH_CONFIGS
643
637
  from .datatypes import TruthParameter
644
638
 
645
- truth_parameters: TruthParameters = []
639
+ truth_parameters: list[TruthParameter] = []
646
640
  for name, truth_config in config["truth_configs"].items():
647
641
  optional_config = deepcopy(truth_config)
648
642
  for key in REQUIRED_TRUTH_CONFIGS:
649
643
  del optional_config[key]
650
644
 
651
- t_config = TruthFunctionConfig(
652
- feature=config["feature"],
653
- num_classes=config["num_classes"],
654
- class_indices=[1],
655
- target_gain=1,
656
- config=optional_config,
645
+ parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
646
+ config["feature"],
647
+ config["num_classes"],
648
+ optional_config,
657
649
  )
658
-
659
- parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
660
650
  truth_parameters.append(TruthParameter(name, parameters))
661
651
 
662
652
  return truth_parameters
@@ -128,6 +128,22 @@ def write_pickle_data(location: str, index: str, items: list[tuple[str, Any]] |
128
128
  f.write(pickle.dumps(item[1]))
129
129
 
130
130
 
131
+ def clear_pickle_data(location: str, index: str, items: list[str] | str) -> None:
132
+ """Clear mixture, target, or noise data pickle file
133
+
134
+ :param location: Location of the file
135
+ :param index: Mixture, target, or noise index
136
+ :param items: String(s) of data to retrieve
137
+ """
138
+ from pathlib import Path
139
+
140
+ if not isinstance(items, list):
141
+ items = [items]
142
+
143
+ for item in items:
144
+ Path(_get_pickle_name(location, index, item)).unlink(missing_ok=True)
145
+
146
+
131
147
  def read_cached_data(location: str, name: str, index: str, items: list[str] | str) -> Any:
132
148
  """Read cached data from a file
133
149
 
@@ -143,7 +159,7 @@ def read_cached_data(location: str, name: str, index: str, items: list[str] | st
143
159
 
144
160
 
145
161
  def write_cached_data(location: str, name: str, index: str, items: list[tuple[str, Any]] | tuple[str, Any]) -> None:
146
- """Write mixture data to a file
162
+ """Write data to a file
147
163
 
148
164
  :param location: Location of the mixture database
149
165
  :param name: Data name ('mixture', 'target', or 'noise')
@@ -153,3 +169,16 @@ def write_cached_data(location: str, name: str, index: str, items: list[tuple[st
153
169
  from os.path import join
154
170
 
155
171
  write_pickle_data(join(location, name), index, items)
172
+
173
+
174
+ def clear_cached_data(location: str, name: str, index: str, items: list[str] | str) -> None:
175
+ """Remove cached data file(s)
176
+
177
+ :param location: Location of the mixture database
178
+ :param name: Data name ('mixture', 'target', or 'noise')
179
+ :param index: Data index (mixture, target, or noise ID)
180
+ :param items: String(s) of data to clear
181
+ """
182
+ from os.path import join
183
+
184
+ clear_pickle_data(join(location, name), index, items)
@@ -12,16 +12,12 @@ from dataclasses_json import DataClassJsonMixin
12
12
  from praatio.utilities.constants import Interval
13
13
 
14
14
  AudioT: TypeAlias = npt.NDArray[np.float32]
15
- AudiosT: TypeAlias = list[AudioT]
16
15
 
17
- ListAudiosT: TypeAlias = list[AudiosT]
18
-
19
- Truth: TypeAlias = npt.NDArray[np.float32]
16
+ Truth: TypeAlias = Any
20
17
  TruthDict: TypeAlias = dict[str, Truth]
21
18
  Segsnr: TypeAlias = npt.NDArray[np.float32]
22
19
 
23
20
  AudioF: TypeAlias = npt.NDArray[np.complex64]
24
- AudiosF: TypeAlias = list[AudioF]
25
21
 
26
22
  EnergyT: TypeAlias = npt.NDArray[np.float32]
27
23
  EnergyF: TypeAlias = npt.NDArray[np.float32]
@@ -92,9 +88,6 @@ class AugmentationRule(DataClassSonusAIMixin):
92
88
  mixup: int = 1
93
89
 
94
90
 
95
- AugmentationRules: TypeAlias = list[AugmentationRule]
96
-
97
-
98
91
  @dataclass
99
92
  class Augmentation(DataClassSonusAIMixin):
100
93
  normalize: float | None = None
@@ -108,9 +101,6 @@ class Augmentation(DataClassSonusAIMixin):
108
101
  ir: int | None = None
109
102
 
110
103
 
111
- Augmentations: TypeAlias = list[Augmentation]
112
-
113
-
114
104
  @dataclass(frozen=True)
115
105
  class UniversalSNRGenerator:
116
106
  is_random: bool
@@ -159,18 +149,12 @@ class TargetFile(DataClassSonusAIMixin):
159
149
  return self.samples / SAMPLE_RATE
160
150
 
161
151
 
162
- TargetFiles: TypeAlias = list[TargetFile]
163
-
164
-
165
152
  @dataclass
166
153
  class AugmentedTarget(DataClassSonusAIMixin):
167
154
  target_id: int
168
155
  target_augmentation_id: int
169
156
 
170
157
 
171
- AugmentedTargets: TypeAlias = list[AugmentedTarget]
172
-
173
-
174
158
  @dataclass
175
159
  class NoiseFile(DataClassSonusAIMixin):
176
160
  name: str
@@ -183,7 +167,6 @@ class NoiseFile(DataClassSonusAIMixin):
183
167
  return self.samples / SAMPLE_RATE
184
168
 
185
169
 
186
- NoiseFiles: TypeAlias = list[NoiseFile]
187
170
  ClassCount: TypeAlias = list[int]
188
171
 
189
172
  GeneralizedIDs: TypeAlias = str | int | list[int] | range
@@ -191,11 +174,11 @@ GeneralizedIDs: TypeAlias = str | int | list[int] | range
191
174
 
192
175
  @dataclass
193
176
  class GenMixData:
194
- targets: AudiosT | None = None
177
+ targets: list[AudioT] | None = None
195
178
  target: AudioT | None = None
196
179
  noise: AudioT | None = None
197
180
  mixture: AudioT | None = None
198
- truth_t: TruthDict | None = None
181
+ truth_t: list[TruthDict] | None = None
199
182
  segsnr_t: Segsnr | None = None
200
183
 
201
184
 
@@ -223,9 +206,6 @@ class ImpulseResponseFile:
223
206
  tags: list[str]
224
207
 
225
208
 
226
- ImpulseResponseFiles: TypeAlias = list[ImpulseResponseFile]
227
-
228
-
229
209
  @dataclass(frozen=True)
230
210
  class SpectralMask(DataClassSonusAIMixin):
231
211
  f_max_width: int
@@ -235,23 +215,24 @@ class SpectralMask(DataClassSonusAIMixin):
235
215
  t_max_percent: int
236
216
 
237
217
 
238
- SpectralMasks: TypeAlias = list[SpectralMask]
239
-
240
-
241
218
  @dataclass(frozen=True)
242
219
  class TruthParameter(DataClassSonusAIMixin):
243
220
  name: str
244
- parameters: int
245
-
246
-
247
- TruthParameters: TypeAlias = list[TruthParameter]
221
+ parameters: int | None
248
222
 
249
223
 
250
224
  @dataclass
251
225
  class Target(DataClassSonusAIMixin):
252
226
  file_id: int
253
227
  augmentation: Augmentation
254
- gain: float = 1.0
228
+
229
+ @property
230
+ def gain(self) -> float:
231
+ # gain is used to back out the gain augmentation in order to return the target audio
232
+ # to its normalized level when calculating truth (if needed).
233
+ if self.augmentation.gain is None:
234
+ return 1.0
235
+ return round(10 ** (self.augmentation.gain / 20), ndigits=5)
255
236
 
256
237
 
257
238
  Targets: TypeAlias = list[Target]
@@ -261,14 +242,14 @@ Targets: TypeAlias = list[Target]
261
242
  class Noise(DataClassSonusAIMixin):
262
243
  file_id: int
263
244
  augmentation: Augmentation
264
- offset: int = 0
265
245
 
266
246
 
267
247
  @dataclass
268
248
  class Mixture(DataClassSonusAIMixin):
269
249
  name: str
270
- targets: Targets
250
+ targets: list[Target]
271
251
  noise: Noise
252
+ noise_offset: int
272
253
  samples: int
273
254
  snr: UniversalSNR
274
255
  spectral_mask_id: int
@@ -288,8 +269,16 @@ class Mixture(DataClassSonusAIMixin):
288
269
  def target_augmentations(self) -> list[Augmentation]:
289
270
  return [target.augmentation for target in self.targets]
290
271
 
272
+ @property
273
+ def is_noise_only(self) -> bool:
274
+ return self.snr < -96
275
+
276
+ @property
277
+ def is_target_only(self) -> bool:
278
+ return self.snr > 96
291
279
 
292
- Mixtures: TypeAlias = list[Mixture]
280
+ def target_gain(self, target_index: int) -> float:
281
+ return (self.targets[target_index].gain if not self.is_noise_only else 0) * self.target_snr_gain
293
282
 
294
283
 
295
284
  @dataclass(frozen=True)
@@ -304,7 +293,7 @@ class TransformConfig:
304
293
  @dataclass(frozen=True)
305
294
  class FeatureGeneratorConfig:
306
295
  feature_mode: str
307
- truth_parameters: dict[str, int]
296
+ truth_parameters: dict[str, int | None]
308
297
 
309
298
 
310
299
  @dataclass(frozen=True)
@@ -328,13 +317,13 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):
328
317
  class_labels: list[str]
329
318
  class_weights_threshold: list[float]
330
319
  feature: str
331
- impulse_response_files: ImpulseResponseFiles
332
- mixtures: Mixtures
320
+ impulse_response_files: list[ImpulseResponseFile]
321
+ mixtures: list[Mixture]
333
322
  noise_mix_mode: str
334
- noise_files: NoiseFiles
323
+ noise_files: list[NoiseFile]
335
324
  num_classes: int
336
- spectral_masks: SpectralMasks
337
- target_files: TargetFiles
325
+ spectral_masks: list[SpectralMask]
326
+ target_files: list[TargetFile]
338
327
 
339
328
 
340
329
  SpeechMetadata: TypeAlias = str | list[Interval] | None
@@ -35,7 +35,7 @@ SpectralMaskRecord = namedtuple(
35
35
  ["id", "f_max_width", "f_num", "t_max_width", "t_num", "t_max_percent"],
36
36
  )
37
37
 
38
- TargetRecord = namedtuple("TargetRecord", ["id", "file_id", "augmentation", "gain"])
38
+ TargetRecord = namedtuple("TargetRecord", ["id", "file_id", "augmentation"])
39
39
 
40
40
  MixtureRecord = namedtuple(
41
41
  "MixtureRecord",
@@ -12,7 +12,6 @@ def get_feature_from_audio(
12
12
  :param feature_mode: Feature mode
13
13
  :return: Feature data [frames, strides, feature_parameters]
14
14
  """
15
- import numpy as np
16
15
  from pyaaware import FeatureGenerator
17
16
 
18
17
  from .datatypes import TransformConfig
@@ -31,33 +30,14 @@ def get_feature_from_audio(
31
30
  ),
32
31
  )
33
32
 
34
- transform_frames = audio_f.shape[0]
35
- feature_frames = transform_frames // (fg.decimation * fg.step)
36
- feature = np.empty((feature_frames, fg.stride, fg.feature_parameters), dtype=np.float32)
37
-
38
- feature_frame = 0
39
- for transform_frame in range(transform_frames):
40
- fg.execute(audio_f[transform_frame])
41
-
42
- if fg.eof():
43
- feature[feature_frame] = fg.feature()
44
- feature_frame += 1
33
+ return fg.execute_all(audio_f)[0]
45
34
 
46
- return feature
47
35
 
48
-
49
- def get_audio_from_feature(
50
- feature: Feature,
51
- feature_mode: str,
52
- num_classes: int | None = 1,
53
- truth_mutex: bool | None = False,
54
- ) -> AudioT:
36
+ def get_audio_from_feature(feature: Feature, feature_mode: str) -> AudioT:
55
37
  """Apply inverse transform to feature data to generate audio data
56
38
 
57
39
  :param feature: Feature data [frames, stride=1, feature_parameters]
58
40
  :param feature_mode: Feature mode
59
- :param num_classes: Number of classes
60
- :param truth_mutex: Whether to calculate 'other' label
61
41
  :return: Audio data [samples]
62
42
  """
63
43
  import numpy as np
@@ -75,7 +55,7 @@ def get_audio_from_feature(
75
55
  if feature.shape[1] != 1:
76
56
  raise ValueError("Strided feature data is not supported for audio extraction; stride must be 1.")
77
57
 
78
- fg = FeatureGenerator(feature_mode=feature_mode, num_classes=num_classes, truth_mutex=truth_mutex)
58
+ fg = FeatureGenerator(feature_mode=feature_mode)
79
59
 
80
60
  feature_complex = unstack_complex(feature.squeeze())
81
61
  if feature_mode[0:1] == "h":