sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +56 -64
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +161 -204
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/config.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
from sonusai.mixture.datatypes import ImpulseResponseFile
|
2
|
-
from sonusai.mixture.datatypes import ImpulseResponseFiles
|
3
|
-
from sonusai.mixture.datatypes import NoiseFiles
|
4
|
-
from sonusai.mixture.datatypes import SpectralMasks
|
5
|
-
from sonusai.mixture.datatypes import TargetFiles
|
6
|
-
from sonusai.mixture.datatypes import TruthParameters
|
2
|
+
from sonusai.mixture.datatypes import NoiseFile
|
3
|
+
from sonusai.mixture.datatypes import SpectralMask
|
4
|
+
from sonusai.mixture.datatypes import TargetFile
|
5
|
+
from sonusai.mixture.datatypes import TruthParameter
|
7
6
|
|
8
7
|
|
9
8
|
def raw_load_config(name: str) -> dict:
|
@@ -210,7 +209,7 @@ def update_config_from_hierarchy(root: str, leaf: str, config: dict) -> dict:
|
|
210
209
|
return new_config
|
211
210
|
|
212
211
|
|
213
|
-
def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
|
212
|
+
def get_target_files(config: dict, show_progress: bool = False) -> list[TargetFile]:
|
214
213
|
"""Get the list of target files from a config
|
215
214
|
|
216
215
|
:param config: Config dictionary
|
@@ -223,7 +222,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
|
|
223
222
|
from sonusai.utils import par_track
|
224
223
|
from sonusai.utils import track
|
225
224
|
|
226
|
-
from .datatypes import TargetFiles
|
225
|
+
from .datatypes import TargetFile
|
227
226
|
|
228
227
|
class_indices = config["class_indices"]
|
229
228
|
if not isinstance(class_indices, list):
|
@@ -255,7 +254,7 @@ def get_target_files(config: dict, show_progress: bool = False) -> TargetFiles:
|
|
255
254
|
if any(class_index > num_classes for class_index in target_file["class_indices"]):
|
256
255
|
raise ValueError(f"class index elements must not be greater than {num_classes}")
|
257
256
|
|
258
|
-
return dataclass_from_dict(TargetFiles, target_files)
|
257
|
+
return dataclass_from_dict(list[TargetFile], target_files)
|
259
258
|
|
260
259
|
|
261
260
|
def append_target_files(
|
@@ -294,6 +293,7 @@ def append_target_files(
|
|
294
293
|
if tokens is None:
|
295
294
|
tokens = {}
|
296
295
|
|
296
|
+
truth_configs_merged = deepcopy(truth_configs)
|
297
297
|
if isinstance(entry, dict):
|
298
298
|
if "name" in entry:
|
299
299
|
in_name = entry["name"]
|
@@ -312,15 +312,11 @@ def append_target_files(
|
|
312
312
|
raise AttributeError(
|
313
313
|
f"Truth config '{key}' override specified for {entry['name']} is not defined at top level"
|
314
314
|
)
|
315
|
-
|
316
|
-
for key in truth_configs_override:
|
317
|
-
truth_configs_merged[key] = deepcopy(truth_configs[key])
|
318
|
-
if truth_configs_override[key] is not None:
|
315
|
+
if key in truth_configs_override:
|
319
316
|
truth_configs_merged[key] |= truth_configs_override[key]
|
320
317
|
level_type = entry.get("level_type", level_type)
|
321
318
|
else:
|
322
319
|
in_name = entry
|
323
|
-
truth_configs_merged = deepcopy(truth_configs)
|
324
320
|
|
325
321
|
in_name, new_tokens = tokenized_expand(in_name)
|
326
322
|
tokens.update(new_tokens)
|
@@ -416,7 +412,7 @@ def append_target_files(
|
|
416
412
|
return target_files
|
417
413
|
|
418
414
|
|
419
|
-
def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
|
415
|
+
def get_noise_files(config: dict, show_progress: bool = False) -> list[NoiseFile]:
|
420
416
|
"""Get the list of noise files from a config
|
421
417
|
|
422
418
|
:param config: Config dictionary
|
@@ -429,7 +425,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
|
|
429
425
|
from sonusai.utils import par_track
|
430
426
|
from sonusai.utils import track
|
431
427
|
|
432
|
-
from .datatypes import NoiseFiles
|
428
|
+
from .datatypes import NoiseFile
|
433
429
|
|
434
430
|
noise_files = list(chain.from_iterable([append_noise_files(entry=entry) for entry in config["noises"]]))
|
435
431
|
|
@@ -437,7 +433,7 @@ def get_noise_files(config: dict, show_progress: bool = False) -> NoiseFiles:
|
|
437
433
|
noise_files = par_track(_get_num_samples, noise_files, progress=progress)
|
438
434
|
progress.close()
|
439
435
|
|
440
|
-
return dataclass_from_dict(NoiseFiles, noise_files)
|
436
|
+
return dataclass_from_dict(list[NoiseFile], noise_files)
|
441
437
|
|
442
438
|
|
443
439
|
def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[dict]:
|
@@ -522,26 +518,25 @@ def append_noise_files(entry: dict | str, tokens: dict | None = None) -> list[di
|
|
522
518
|
return noise_files
|
523
519
|
|
524
520
|
|
525
|
-
def get_impulse_response_files(config: dict) -> ImpulseResponseFiles:
|
521
|
+
def get_impulse_response_files(config: dict) -> list[ImpulseResponseFile]:
|
526
522
|
"""Get the list of impulse response files from a config
|
527
523
|
|
528
524
|
:param config: Config dictionary
|
529
525
|
:return: List of impulse response files
|
530
526
|
"""
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[str]:
|
527
|
+
from itertools import chain
|
528
|
+
|
529
|
+
return list(
|
530
|
+
chain.from_iterable(
|
531
|
+
[
|
532
|
+
append_impulse_response_files(entry=ImpulseResponseFile(entry["name"], entry.get("tags", [])))
|
533
|
+
for entry in config["impulse_responses"]
|
534
|
+
]
|
535
|
+
)
|
536
|
+
)
|
537
|
+
|
538
|
+
|
539
|
+
def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | None = None) -> list[ImpulseResponseFile]:
|
545
540
|
"""Process impulse response files list and append as needed
|
546
541
|
|
547
542
|
:param entry: Impulse response file entry to append to the list
|
@@ -569,7 +564,7 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
569
564
|
if not names:
|
570
565
|
raise OSError(f"Could not find {in_name}. Make sure path exists")
|
571
566
|
|
572
|
-
impulse_response_files: list[str] = []
|
567
|
+
impulse_response_files: list[ImpulseResponseFile] = []
|
573
568
|
for name in names:
|
574
569
|
ext = splitext(name)[1].lower()
|
575
570
|
dir_name = dirname(name)
|
@@ -607,14 +602,14 @@ def append_impulse_response_files(entry: ImpulseResponseFile, tokens: dict | Non
|
|
607
602
|
raise OSError(f"Error processing {name}: {e}") from e
|
608
603
|
else:
|
609
604
|
validate_input_file(name)
|
610
|
-
impulse_response_files.append(tokenized_replace(name, tokens))
|
605
|
+
impulse_response_files.append(ImpulseResponseFile(tokenized_replace(name, tokens), entry.tags))
|
611
606
|
except Exception as e:
|
612
607
|
raise OSError(f"Error processing {name}: {e}") from e
|
613
608
|
|
614
609
|
return impulse_response_files
|
615
610
|
|
616
611
|
|
617
|
-
def get_spectral_masks(config: dict) -> SpectralMasks:
|
612
|
+
def get_spectral_masks(config: dict) -> list[SpectralMask]:
|
618
613
|
"""Get the list of spectral masks from a config
|
619
614
|
|
620
615
|
:param config: Config dictionary
|
@@ -623,12 +618,12 @@ def get_spectral_masks(config: dict) -> SpectralMasks:
|
|
623
618
|
from sonusai.utils import dataclass_from_dict
|
624
619
|
|
625
620
|
try:
|
626
|
-
return dataclass_from_dict(SpectralMasks, config["spectral_masks"])
|
621
|
+
return dataclass_from_dict(list[SpectralMask], config["spectral_masks"])
|
627
622
|
except Exception as e:
|
628
623
|
raise ValueError(f"Error in spectral_masks: {e}") from e
|
629
624
|
|
630
625
|
|
631
|
-
def get_truth_parameters(config: dict) -> TruthParameters:
|
626
|
+
def get_truth_parameters(config: dict) -> list[TruthParameter]:
|
632
627
|
"""Get the list of truth parameters from a config
|
633
628
|
|
634
629
|
:param config: Config dictionary
|
@@ -637,26 +632,21 @@ def get_truth_parameters(config: dict) -> TruthParameters:
|
|
637
632
|
from copy import deepcopy
|
638
633
|
|
639
634
|
from sonusai.mixture import truth_functions
|
640
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
641
635
|
|
642
636
|
from .constants import REQUIRED_TRUTH_CONFIGS
|
643
637
|
from .datatypes import TruthParameter
|
644
638
|
|
645
|
-
truth_parameters: TruthParameters = []
|
639
|
+
truth_parameters: list[TruthParameter] = []
|
646
640
|
for name, truth_config in config["truth_configs"].items():
|
647
641
|
optional_config = deepcopy(truth_config)
|
648
642
|
for key in REQUIRED_TRUTH_CONFIGS:
|
649
643
|
del optional_config[key]
|
650
644
|
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
target_gain=1,
|
656
|
-
config=optional_config,
|
645
|
+
parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(
|
646
|
+
config["feature"],
|
647
|
+
config["num_classes"],
|
648
|
+
optional_config,
|
657
649
|
)
|
658
|
-
|
659
|
-
parameters = getattr(truth_functions, truth_config["function"] + "_parameters")(t_config)
|
660
650
|
truth_parameters.append(TruthParameter(name, parameters))
|
661
651
|
|
662
652
|
return truth_parameters
|
sonusai/mixture/data_io.py
CHANGED
@@ -128,6 +128,22 @@ def write_pickle_data(location: str, index: str, items: list[tuple[str, Any]] |
|
|
128
128
|
f.write(pickle.dumps(item[1]))
|
129
129
|
|
130
130
|
|
131
|
+
def clear_pickle_data(location: str, index: str, items: list[str] | str) -> None:
|
132
|
+
"""Clear mixture, target, or noise data pickle file
|
133
|
+
|
134
|
+
:param location: Location of the file
|
135
|
+
:param index: Mixture, target, or noise index
|
136
|
+
:param items: String(s) of data to retrieve
|
137
|
+
"""
|
138
|
+
from pathlib import Path
|
139
|
+
|
140
|
+
if not isinstance(items, list):
|
141
|
+
items = [items]
|
142
|
+
|
143
|
+
for item in items:
|
144
|
+
Path(_get_pickle_name(location, index, item)).unlink(missing_ok=True)
|
145
|
+
|
146
|
+
|
131
147
|
def read_cached_data(location: str, name: str, index: str, items: list[str] | str) -> Any:
|
132
148
|
"""Read cached data from a file
|
133
149
|
|
@@ -143,7 +159,7 @@ def read_cached_data(location: str, name: str, index: str, items: list[str] | st
|
|
143
159
|
|
144
160
|
|
145
161
|
def write_cached_data(location: str, name: str, index: str, items: list[tuple[str, Any]] | tuple[str, Any]) -> None:
|
146
|
-
"""Write
|
162
|
+
"""Write data to a file
|
147
163
|
|
148
164
|
:param location: Location of the mixture database
|
149
165
|
:param name: Data name ('mixture', 'target', or 'noise')
|
@@ -153,3 +169,16 @@ def write_cached_data(location: str, name: str, index: str, items: list[tuple[st
|
|
153
169
|
from os.path import join
|
154
170
|
|
155
171
|
write_pickle_data(join(location, name), index, items)
|
172
|
+
|
173
|
+
|
174
|
+
def clear_cached_data(location: str, name: str, index: str, items: list[str] | str) -> None:
|
175
|
+
"""Remove cached data file(s)
|
176
|
+
|
177
|
+
:param location: Location of the mixture database
|
178
|
+
:param name: Data name ('mixture', 'target', or 'noise')
|
179
|
+
:param index: Data index (mixture, target, or noise ID)
|
180
|
+
:param items: String(s) of data to clear
|
181
|
+
"""
|
182
|
+
from os.path import join
|
183
|
+
|
184
|
+
clear_pickle_data(join(location, name), index, items)
|
sonusai/mixture/datatypes.py
CHANGED
@@ -12,16 +12,12 @@ from dataclasses_json import DataClassJsonMixin
|
|
12
12
|
from praatio.utilities.constants import Interval
|
13
13
|
|
14
14
|
AudioT: TypeAlias = npt.NDArray[np.float32]
|
15
|
-
AudiosT: TypeAlias = list[AudioT]
|
16
15
|
|
17
|
-
|
18
|
-
|
19
|
-
Truth: TypeAlias = npt.NDArray[np.float32]
|
16
|
+
Truth: TypeAlias = Any
|
20
17
|
TruthDict: TypeAlias = dict[str, Truth]
|
21
18
|
Segsnr: TypeAlias = npt.NDArray[np.float32]
|
22
19
|
|
23
20
|
AudioF: TypeAlias = npt.NDArray[np.complex64]
|
24
|
-
AudiosF: TypeAlias = list[AudioF]
|
25
21
|
|
26
22
|
EnergyT: TypeAlias = npt.NDArray[np.float32]
|
27
23
|
EnergyF: TypeAlias = npt.NDArray[np.float32]
|
@@ -92,9 +88,6 @@ class AugmentationRule(DataClassSonusAIMixin):
|
|
92
88
|
mixup: int = 1
|
93
89
|
|
94
90
|
|
95
|
-
AugmentationRules: TypeAlias = list[AugmentationRule]
|
96
|
-
|
97
|
-
|
98
91
|
@dataclass
|
99
92
|
class Augmentation(DataClassSonusAIMixin):
|
100
93
|
normalize: float | None = None
|
@@ -108,9 +101,6 @@ class Augmentation(DataClassSonusAIMixin):
|
|
108
101
|
ir: int | None = None
|
109
102
|
|
110
103
|
|
111
|
-
Augmentations: TypeAlias = list[Augmentation]
|
112
|
-
|
113
|
-
|
114
104
|
@dataclass(frozen=True)
|
115
105
|
class UniversalSNRGenerator:
|
116
106
|
is_random: bool
|
@@ -159,18 +149,12 @@ class TargetFile(DataClassSonusAIMixin):
|
|
159
149
|
return self.samples / SAMPLE_RATE
|
160
150
|
|
161
151
|
|
162
|
-
TargetFiles: TypeAlias = list[TargetFile]
|
163
|
-
|
164
|
-
|
165
152
|
@dataclass
|
166
153
|
class AugmentedTarget(DataClassSonusAIMixin):
|
167
154
|
target_id: int
|
168
155
|
target_augmentation_id: int
|
169
156
|
|
170
157
|
|
171
|
-
AugmentedTargets: TypeAlias = list[AugmentedTarget]
|
172
|
-
|
173
|
-
|
174
158
|
@dataclass
|
175
159
|
class NoiseFile(DataClassSonusAIMixin):
|
176
160
|
name: str
|
@@ -183,7 +167,6 @@ class NoiseFile(DataClassSonusAIMixin):
|
|
183
167
|
return self.samples / SAMPLE_RATE
|
184
168
|
|
185
169
|
|
186
|
-
NoiseFiles: TypeAlias = list[NoiseFile]
|
187
170
|
ClassCount: TypeAlias = list[int]
|
188
171
|
|
189
172
|
GeneralizedIDs: TypeAlias = str | int | list[int] | range
|
@@ -191,11 +174,11 @@ GeneralizedIDs: TypeAlias = str | int | list[int] | range
|
|
191
174
|
|
192
175
|
@dataclass
|
193
176
|
class GenMixData:
|
194
|
-
targets:
|
177
|
+
targets: list[AudioT] | None = None
|
195
178
|
target: AudioT | None = None
|
196
179
|
noise: AudioT | None = None
|
197
180
|
mixture: AudioT | None = None
|
198
|
-
truth_t: TruthDict | None = None
|
181
|
+
truth_t: list[TruthDict] | None = None
|
199
182
|
segsnr_t: Segsnr | None = None
|
200
183
|
|
201
184
|
|
@@ -223,9 +206,6 @@ class ImpulseResponseFile:
|
|
223
206
|
tags: list[str]
|
224
207
|
|
225
208
|
|
226
|
-
ImpulseResponseFiles: TypeAlias = list[ImpulseResponseFile]
|
227
|
-
|
228
|
-
|
229
209
|
@dataclass(frozen=True)
|
230
210
|
class SpectralMask(DataClassSonusAIMixin):
|
231
211
|
f_max_width: int
|
@@ -235,23 +215,24 @@ class SpectralMask(DataClassSonusAIMixin):
|
|
235
215
|
t_max_percent: int
|
236
216
|
|
237
217
|
|
238
|
-
SpectralMasks: TypeAlias = list[SpectralMask]
|
239
|
-
|
240
|
-
|
241
218
|
@dataclass(frozen=True)
|
242
219
|
class TruthParameter(DataClassSonusAIMixin):
|
243
220
|
name: str
|
244
|
-
parameters: int
|
245
|
-
|
246
|
-
|
247
|
-
TruthParameters: TypeAlias = list[TruthParameter]
|
221
|
+
parameters: int | None
|
248
222
|
|
249
223
|
|
250
224
|
@dataclass
|
251
225
|
class Target(DataClassSonusAIMixin):
|
252
226
|
file_id: int
|
253
227
|
augmentation: Augmentation
|
254
|
-
|
228
|
+
|
229
|
+
@property
|
230
|
+
def gain(self) -> float:
|
231
|
+
# gain is used to back out the gain augmentation in order to return the target audio
|
232
|
+
# to its normalized level when calculating truth (if needed).
|
233
|
+
if self.augmentation.gain is None:
|
234
|
+
return 1.0
|
235
|
+
return round(10 ** (self.augmentation.gain / 20), ndigits=5)
|
255
236
|
|
256
237
|
|
257
238
|
Targets: TypeAlias = list[Target]
|
@@ -261,14 +242,14 @@ Targets: TypeAlias = list[Target]
|
|
261
242
|
class Noise(DataClassSonusAIMixin):
|
262
243
|
file_id: int
|
263
244
|
augmentation: Augmentation
|
264
|
-
offset: int = 0
|
265
245
|
|
266
246
|
|
267
247
|
@dataclass
|
268
248
|
class Mixture(DataClassSonusAIMixin):
|
269
249
|
name: str
|
270
|
-
targets:
|
250
|
+
targets: list[Target]
|
271
251
|
noise: Noise
|
252
|
+
noise_offset: int
|
272
253
|
samples: int
|
273
254
|
snr: UniversalSNR
|
274
255
|
spectral_mask_id: int
|
@@ -288,8 +269,16 @@ class Mixture(DataClassSonusAIMixin):
|
|
288
269
|
def target_augmentations(self) -> list[Augmentation]:
|
289
270
|
return [target.augmentation for target in self.targets]
|
290
271
|
|
272
|
+
@property
|
273
|
+
def is_noise_only(self) -> bool:
|
274
|
+
return self.snr < -96
|
275
|
+
|
276
|
+
@property
|
277
|
+
def is_target_only(self) -> bool:
|
278
|
+
return self.snr > 96
|
291
279
|
|
292
|
-
|
280
|
+
def target_gain(self, target_index: int) -> float:
|
281
|
+
return (self.targets[target_index].gain if not self.is_noise_only else 0) * self.target_snr_gain
|
293
282
|
|
294
283
|
|
295
284
|
@dataclass(frozen=True)
|
@@ -304,7 +293,7 @@ class TransformConfig:
|
|
304
293
|
@dataclass(frozen=True)
|
305
294
|
class FeatureGeneratorConfig:
|
306
295
|
feature_mode: str
|
307
|
-
truth_parameters: dict[str, int]
|
296
|
+
truth_parameters: dict[str, int | None]
|
308
297
|
|
309
298
|
|
310
299
|
@dataclass(frozen=True)
|
@@ -328,13 +317,13 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):
|
|
328
317
|
class_labels: list[str]
|
329
318
|
class_weights_threshold: list[float]
|
330
319
|
feature: str
|
331
|
-
impulse_response_files:
|
332
|
-
mixtures:
|
320
|
+
impulse_response_files: list[ImpulseResponseFile]
|
321
|
+
mixtures: list[Mixture]
|
333
322
|
noise_mix_mode: str
|
334
|
-
noise_files:
|
323
|
+
noise_files: list[NoiseFile]
|
335
324
|
num_classes: int
|
336
|
-
spectral_masks:
|
337
|
-
target_files:
|
325
|
+
spectral_masks: list[SpectralMask]
|
326
|
+
target_files: list[TargetFile]
|
338
327
|
|
339
328
|
|
340
329
|
SpeechMetadata: TypeAlias = str | list[Interval] | None
|
sonusai/mixture/db_datatypes.py
CHANGED
@@ -35,7 +35,7 @@ SpectralMaskRecord = namedtuple(
|
|
35
35
|
["id", "f_max_width", "f_num", "t_max_width", "t_num", "t_max_percent"],
|
36
36
|
)
|
37
37
|
|
38
|
-
TargetRecord = namedtuple("TargetRecord", ["id", "file_id", "augmentation"
|
38
|
+
TargetRecord = namedtuple("TargetRecord", ["id", "file_id", "augmentation"])
|
39
39
|
|
40
40
|
MixtureRecord = namedtuple(
|
41
41
|
"MixtureRecord",
|
sonusai/mixture/feature.py
CHANGED
@@ -12,7 +12,6 @@ def get_feature_from_audio(
|
|
12
12
|
:param feature_mode: Feature mode
|
13
13
|
:return: Feature data [frames, strides, feature_parameters]
|
14
14
|
"""
|
15
|
-
import numpy as np
|
16
15
|
from pyaaware import FeatureGenerator
|
17
16
|
|
18
17
|
from .datatypes import TransformConfig
|
@@ -31,33 +30,14 @@ def get_feature_from_audio(
|
|
31
30
|
),
|
32
31
|
)
|
33
32
|
|
34
|
-
|
35
|
-
feature_frames = transform_frames // (fg.decimation * fg.step)
|
36
|
-
feature = np.empty((feature_frames, fg.stride, fg.feature_parameters), dtype=np.float32)
|
37
|
-
|
38
|
-
feature_frame = 0
|
39
|
-
for transform_frame in range(transform_frames):
|
40
|
-
fg.execute(audio_f[transform_frame])
|
41
|
-
|
42
|
-
if fg.eof():
|
43
|
-
feature[feature_frame] = fg.feature()
|
44
|
-
feature_frame += 1
|
33
|
+
return fg.execute_all(audio_f)[0]
|
45
34
|
|
46
|
-
return feature
|
47
35
|
|
48
|
-
|
49
|
-
def get_audio_from_feature(
|
50
|
-
feature: Feature,
|
51
|
-
feature_mode: str,
|
52
|
-
num_classes: int | None = 1,
|
53
|
-
truth_mutex: bool | None = False,
|
54
|
-
) -> AudioT:
|
36
|
+
def get_audio_from_feature(feature: Feature, feature_mode: str) -> AudioT:
|
55
37
|
"""Apply inverse transform to feature data to generate audio data
|
56
38
|
|
57
39
|
:param feature: Feature data [frames, stride=1, feature_parameters]
|
58
40
|
:param feature_mode: Feature mode
|
59
|
-
:param num_classes: Number of classes
|
60
|
-
:param truth_mutex: Whether to calculate 'other' label
|
61
41
|
:return: Audio data [samples]
|
62
42
|
"""
|
63
43
|
import numpy as np
|
@@ -75,7 +55,7 @@ def get_audio_from_feature(
|
|
75
55
|
if feature.shape[1] != 1:
|
76
56
|
raise ValueError("Strided feature data is not supported for audio extraction; stride must be 1.")
|
77
57
|
|
78
|
-
fg = FeatureGenerator(feature_mode=feature_mode
|
58
|
+
fg = FeatureGenerator(feature_mode=feature_mode)
|
79
59
|
|
80
60
|
feature_complex = unstack_complex(feature.squeeze())
|
81
61
|
if feature_mode[0:1] == "h":
|