sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
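A recurring change in the modules shown below is that they no longer import `SonusAIError`; failures now surface as built-in exceptions (`OSError`, `ValueError`, `RuntimeError`). A minimal sketch, assuming the module paths in this wheel, of how a caller might adjust its error handling (the file name is illustrative, not part of the package):

```python
# Hedged sketch, not part of the package: in 0.19.x,
# sonusai.mixture.torchaudio_audio raises OSError on read failures
# instead of the old SonusAIError.
from sonusai.mixture.torchaudio_audio import read_audio

try:
    audio = read_audio("speech.wav")  # illustrative path
except OSError as e:
    print(f"could not read audio: {e}")
```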
sonusai/mixture/target_class_balancing.py
CHANGED

@@ -5,14 +5,15 @@ from sonusai.mixture.datatypes import TargetFile
 from sonusai.mixture.datatypes import TargetFiles
 
 
-def balance_targets(
+def balance_targets(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    class_balancing_augmentation: AugmentationRule,
+    num_classes: int,
+    num_ir: int,
+    mixups: list[int] | None = None,
+) -> tuple[AugmentedTargets, AugmentationRules]:
     import math
 
     from .augmentation import get_mixups

@@ -34,14 +35,13 @@ def balance_targets(augmented_targets: AugmentedTargets,
         target_augmentations=target_augmentations,
         mixup=mixup,
         num_classes=num_classes,
+    )
 
     largest = max([len(item) for item in augmented_target_indices_by_class])
     largest = math.ceil(largest / mixup) * mixup
     for at_indices in augmented_target_indices_by_class:
         additional_augmentations_needed = largest - len(at_indices)
-        target_ids = sorted(
-            list(set([augmented_targets[at_index].target_id for at_index in at_indices])))
+        target_ids = sorted({augmented_targets[at_index].target_id for at_index in at_indices})
 
         tfi_idx = 0
         for _ in range(additional_augmentations_needed):

@@ -55,50 +55,55 @@ def balance_targets(augmented_targets: AugmentedTargets,
                 target_id=target_id,
                 mixup=mixup,
                 num_ir=num_ir,
-                first_cba_id=first_cba_id
+                first_cba_id=first_cba_id,
+            )
+            augmented_target = AugmentedTarget(target_id=target_id, target_augmentation_id=augmentation_index)
             augmented_targets.append(augmented_target)
 
     return augmented_targets, target_augmentations
 
 
-def _get_unused_balancing_augmentation(
+def _get_unused_balancing_augmentation(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    class_balancing_augmentation: AugmentationRule,
+    target_id: int,
+    mixup: int,
+    num_ir: int,
+    first_cba_id: int,
+) -> tuple[int, AugmentationRules]:
+    """Get an unused balancing augmentation for a given target file index"""
     from dataclasses import asdict
 
     from .augmentation import get_augmentation_rules
 
     balancing_augmentations = [item for item in range(len(target_augmentations)) if item >= first_cba_id]
-    used_balancing_augmentations = [
+    used_balancing_augmentations = [
+        at.target_augmentation_id
+        for at in augmented_targets
+        if at.target_id == target_id and at.target_augmentation_id in balancing_augmentations
+    ]
+
+    augmentation_indices = [
+        item
+        for item in balancing_augmentations
+        if item not in used_balancing_augmentations and target_augmentations[item].mixup == mixup
+    ]
     if len(augmentation_indices) > 0:
         return augmentation_indices[0], target_augmentations
 
-    class_balancing_augmentation = get_class_balancing_augmentation(
+    class_balancing_augmentation = get_class_balancing_augmentation(
+        target=targets[target_id], default_cba=class_balancing_augmentation
+    )
     new_augmentation = get_augmentation_rules(rules=asdict(class_balancing_augmentation), num_ir=num_ir)[0]
     new_augmentation.mixup = mixup
     target_augmentations.append(new_augmentation)
     return len(target_augmentations) - 1, target_augmentations
 
 
-def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule
-    """
-    """
+def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule:
+    """Get the class balancing augmentation rule for the given target"""
     if target.class_balancing_augmentation is not None:
         return target.class_balancing_augmentation
     return default_cba
sonusai/mixture/targets.py
CHANGED
@@ -1,13 +1,14 @@
 from sonusai.mixture.datatypes import AugmentationRules
 from sonusai.mixture.datatypes import AugmentedTarget
 from sonusai.mixture.datatypes import AugmentedTargets
-from sonusai.mixture.datatypes import TargetFile
 from sonusai.mixture.datatypes import TargetFiles
 
 
-def get_augmented_targets(
+def get_augmented_targets(
+    target_files: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixups: list[int] | None = None,
+) -> AugmentedTargets:
     from .augmentation import get_augmentation_indices_for_mixup
     from .augmentation import get_mixups
 

@@ -19,85 +20,82 @@ def get_augmented_targets(target_files: TargetFiles,
         augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
         for target_index in range(len(target_files)):
             for augmentation_index in augmentation_indices:
-                augmented_targets.append(
+                augmented_targets.append(
+                    AugmentedTarget(
+                        target_id=target_index,
+                        target_augmentation_id=augmentation_index,
+                    )
+                )
 
     return augmented_targets
 
 
-def
-    index = [truth_setting.index for truth_setting in target.truth_settings]
-    # flatten, uniquify, and sort
-    return sorted(list(set([item for sublist in index for item in sublist])))
-def get_truth_indices_for_augmented_target(augmented_target: AugmentedTarget, targets: TargetFiles) -> list[int]:
-    return get_truth_indices_for_target(targets[augmented_target.target_id])
+def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: TargetFiles) -> list[int]:
+    return targets[augmented_target.target_id].class_indices
 
 
 def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: AugmentationRules) -> int:
     return augmentations[augmented_target.target_augmentation_id].mixup
 
 
-def
-        allow_multiple: bool = False) -> list[int]:
-    """Get a list of target indices containing the given truth index.
+def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow_multiple: bool = False) -> list[int]:
+    """Get a list of target indices containing the given class index.
 
-    If allow_multiple is True, then include targets that contain multiple
+    If allow_multiple is True, then include targets that contain multiple class indices.
     """
     target_indices = set()
     for target_index, target in enumerate(targets):
-        indices =
+        indices = target.class_indices
         if len(indices) == 1 or allow_multiple:
             for index in indices:
-                if index ==
+                if index == class_index + 1:
                     target_indices.add(target_index)
 
-    return sorted(
+    return sorted(target_indices)
 
 
-def
+def get_augmented_target_ids_for_class_index(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    augmentations: AugmentationRules,
+    class_index: int,
+    mixup: int,
+    allow_multiple: bool = False,
+) -> list[int]:
+    """Get a list of augmented target indices containing the given class index.
 
-    If allow_multiple is True, then include targets that contain multiple
+    If allow_multiple is True, then include targets that contain multiple class indices.
     """
     augmented_target_ids = set()
     for augmented_target_id, augmented_target in enumerate(augmented_targets):
         if get_mixup_for_augmented_target(augmented_target=augmented_target, augmentations=augmentations) == mixup:
-            indices =
+            indices = get_class_index_for_augmented_target(augmented_target=augmented_target, targets=targets)
             if len(indices) == 1 or allow_multiple:
                 for index in indices:
-                    if index ==
+                    if index == class_index + 1:
                         augmented_target_ids.add(augmented_target_id)
 
-    return sorted(
+    return sorted(augmented_target_ids)
 
 
-def get_augmented_target_ids_by_class(
-    num_classes -= 1
+def get_augmented_target_ids_by_class(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixup: int,
+    num_classes: int,
+) -> list[list[int]]:
     indices = []
     for idx in range(num_classes):
         indices.append(
+            get_augmented_target_ids_for_class_index(
+                augmented_targets=augmented_targets,
+                targets=targets,
+                augmentations=target_augmentations,
+                class_index=idx,
+                mixup=mixup,
+            )
+        )
     return indices
 

@@ -111,36 +109,40 @@ def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules,
     return [target_augmentation for target_augmentation in target_augmentations if target_augmentation.mixup == mixup]
 
 
-def get_augmented_target_ids_for_mixup(
+def get_augmented_target_ids_for_mixup(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixup: int,
+    num_classes: int,
+) -> list[list[int]]:
     from collections import deque
     from random import shuffle
 
-    from sonusai import SonusAIError
-
     mixup_indices = []
 
     if mixup == 1:
         for index, augmented_target in enumerate(augmented_targets):
-            if
+            if (
+                get_mixup_for_augmented_target(
+                    augmented_target=augmented_target,
+                    augmentations=target_augmentations,
+                )
+                == 1
+            ):
                 mixup_indices.append([index])
         return mixup_indices
 
-    augmented_target_ids_by_class = get_augmented_target_ids_by_class(
+    augmented_target_ids_by_class = get_augmented_target_ids_by_class(
+        augmented_targets=augmented_targets,
+        targets=targets,
+        target_augmentations=target_augmentations,
+        mixup=mixup,
+        num_classes=num_classes,
+    )
 
     if mixup > num_classes:
-        raise
-            f'Specified mixup, {mixup}, is greater than the number of classes, {num_classes}')
+        raise ValueError(f"Specified mixup, {mixup}, is greater than the number of classes, {num_classes}")
 
     de: deque[int] = deque()
 
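The `index == class_index + 1` tests above imply that `class_index` arguments are 0-based while `TargetFile.class_indices` hold 1-based labels. A self-contained sketch of that convention, restating the lookup logic from the diff with a stand-in dataclass rather than sonusai's own `TargetFile`:

```python
# Hedged restatement of the 0.19.x lookup logic shown above, using a fake target type.
from dataclasses import dataclass


@dataclass
class FakeTarget:
    class_indices: list[int]  # 1-based labels, as in TargetFile.class_indices


def target_ids_for_class_index(targets, class_index, allow_multiple=False):
    target_indices = set()
    for target_index, target in enumerate(targets):
        indices = target.class_indices
        if len(indices) == 1 or allow_multiple:
            for index in indices:
                if index == class_index + 1:  # 0-based argument vs 1-based labels
                    target_indices.add(target_index)
    return sorted(target_indices)


targets = [FakeTarget([1]), FakeTarget([2]), FakeTarget([1, 2])]
print(target_ids_for_class_index(targets, class_index=0))                       # [0]
print(target_ids_for_class_index(targets, class_index=0, allow_multiple=True))  # [0, 2]
```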
sonusai/mixture/tokenized_shell_vars.py
CHANGED

@@ -23,10 +23,10 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
 
     from .constants import DEFAULT_NOISE
 
-    os.environ[
+    os.environ["default_noise"] = str(DEFAULT_NOISE)  # noqa: SIM112
 
     if isinstance(name, bytes):
-        name = name.decode(
+        name = name.decode("utf-8")
 
     if isinstance(name, Path):
         name = name.as_posix()

@@ -34,37 +34,37 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
     name = os.fspath(name)
     token_map: dict = {}
 
-    if
+    if "$" not in name and "%" not in name:
         return name, token_map
 
-    var_chars = string.ascii_letters + string.digits +
-    quote = '
-    percent =
-    brace =
-    rbrace =
-    dollar =
+    var_chars = string.ascii_letters + string.digits + "_-"
+    quote = "'"
+    percent = "%"
+    brace = "{"
+    rbrace = "}"
+    dollar = "$"
     environ = os.environ
 
     result = name[:0]
     index = 0
     path_len = len(name)
     while index < path_len:
-        c = name[index:index + 1]
+        c = name[index : index + 1]
         if c == quote:  # no expansion within single quotes
-            name = name[index + 1:]
+            name = name[index + 1 :]
             path_len = len(name)
             try:
                 index = name.index(c)
-                result += c + name[:index + 1]
+                result += c + name[: index + 1]
             except ValueError:
                 result += c + name
                 index = path_len - 1
         elif c == percent:  # variable or '%'
-            if name[index + 1:index + 2] == percent:
+            if name[index + 1 : index + 2] == percent:
                 result += c
                 index += 1
             else:
-                name = name[index + 1:]
+                name = name[index + 1 :]
                 path_len = len(name)
                 try:
                     index = name.index(percent)

@@ -75,7 +75,7 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                 var = name[:index]
                 try:
                     if environ is None:
-                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode(
+                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                     else:
                         value = environ[var]
                     token_map[var] = value

@@ -83,11 +83,11 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                     value = percent + var + percent
                 result += value
         elif c == dollar:  # variable or '$$'
-            if name[index + 1:index + 2] == dollar:
+            if name[index + 1 : index + 2] == dollar:
                 result += c
                 index += 1
-            elif name[index + 1:index + 2] == brace:
-                name = name[index + 2:]
+            elif name[index + 1 : index + 2] == brace:
+                name = name[index + 2 :]
                 path_len = len(name)
                 try:
                     index = name.index(rbrace)

@@ -98,7 +98,7 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                 var = name[:index]
                 try:
                     if environ is None:
-                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode(
+                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                     else:
                         value = environ[var]
                     token_map[var] = value

@@ -108,14 +108,14 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
             else:
                 var = name[:0]
                 index += 1
-                c = name[index:index + 1]
+                c = name[index : index + 1]
                 while c and c in var_chars:
                     var += c
                     index += 1
-                    c = name[index:index + 1]
+                    c = name[index : index + 1]
                 try:
                     if environ is None:
-                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode(
+                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                     else:
                         value = environ[var]
                     token_map[var] = value

@@ -139,5 +139,5 @@ def tokenized_replace(name: str, tokens: dict[str, str]) -> str:
     :return: replaced string
     """
     for key, value in tokens.items():
-        name = name.replace(value, f
+        name = name.replace(value, f"${key}")
     return name
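`tokenized_expand` records each substituted variable in a token map, and `tokenized_replace` reverses the substitution. A hedged usage sketch, assuming the two functions are imported from `sonusai.mixture.tokenized_shell_vars` as laid out in this wheel; the environment variable and path are illustrative:

```python
import os

from sonusai.mixture.tokenized_shell_vars import tokenized_expand, tokenized_replace

os.environ["MY_DATA"] = "/data/corpora"  # illustrative variable

expanded, tokens = tokenized_expand("$MY_DATA/speech/sample.wav")
# expanded == "/data/corpora/speech/sample.wav"
# tokens   == {"MY_DATA": "/data/corpora"}

restored = tokenized_replace(expanded, tokens)
# restored == "$MY_DATA/speech/sample.wav"
```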
sonusai/mixture/torchaudio_audio.py
CHANGED

@@ -14,19 +14,18 @@ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     import torch
     import torchaudio
 
-    from sonusai import SonusAIError
     from .tokenized_shell_vars import tokenized_expand
 
     expanded_name, _ = tokenized_expand(name)
 
     # Read impulse response data from audio file
     try:
-        raw, sample_rate = torchaudio.load(expanded_name, backend=
+        raw, sample_rate = torchaudio.load(expanded_name, backend="soundfile")
     except Exception as e:
         if name != expanded_name:
-            raise
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
         else:
-            raise
+            raise OSError(f"Error reading {name}: {e}") from e
 
     raw = torch.squeeze(raw[0, :])
     offset = torch.argmax(raw)

@@ -49,7 +48,6 @@ def get_sample_rate(name: str | Path) -> int:
     """
     import torchaudio
 
-    from sonusai import SonusAIError
     from .tokenized_shell_vars import tokenized_expand
 
     expanded_name, _ = tokenized_expand(name)

@@ -58,9 +56,9 @@ def get_sample_rate(name: str | Path) -> int:
         return torchaudio.info(expanded_name).sample_rate
     except Exception as e:
         if name != expanded_name:
-            raise
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}):\n{e}") from e
         else:
-            raise
+            raise OSError(f"Error reading {name}:\n{e}") from e
 
 
 def read_audio(name: str | Path) -> AudioT:

@@ -73,24 +71,25 @@ def read_audio(name: str | Path) -> AudioT:
     import torch
     import torchaudio
 
-    from sonusai import SonusAIError
     from .constants import SAMPLE_RATE
     from .tokenized_shell_vars import tokenized_expand
 
     expanded_name, _ = tokenized_expand(name)
 
     try:
-        out, samplerate = torchaudio.load(expanded_name, backend=
+        out, samplerate = torchaudio.load(expanded_name, backend="soundfile")
         out = torch.reshape(out[0, :], (1, out.size()[1]))
-        out = torchaudio.functional.resample(
+        out = torchaudio.functional.resample(
+            out,
+            orig_freq=samplerate,
+            new_freq=SAMPLE_RATE,
+            resampling_method="sinc_interp_hann",
+        )
     except Exception as e:
         if name != expanded_name:
-            raise
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}):\n{e}") from e
         else:
-            raise
+            raise OSError(f"Error reading {name}:\n{e}") from e
 
     result = np.squeeze(np.array(out))
     return result
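Per the hunks above, `read_audio` loads with the soundfile backend and resamples to the package sample rate with `sinc_interp_hann`, so callers receive audio at the SonusAI rate regardless of the file's native rate. A hedged usage sketch, assuming the module paths in this wheel (the file name is illustrative):

```python
from sonusai.mixture.constants import SAMPLE_RATE
from sonusai.mixture.torchaudio_audio import get_sample_rate, read_audio

native_rate = get_sample_rate("noise_48k.wav")  # e.g. 48000 for a 48 kHz file
audio = read_audio("noise_48k.wav")             # 1-D numpy array resampled to SAMPLE_RATE
print(native_rate, SAMPLE_RATE, audio.shape)
```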
sonusai/mixture/torchaudio_augmentation.py
CHANGED

@@ -3,9 +3,7 @@ from sonusai.mixture.datatypes import Augmentation
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def apply_augmentation(audio: AudioT,
-                       augmentation: Augmentation,
-                       frame_length: int = 1) -> AudioT:
+def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length: int = 1) -> AudioT:
     """Apply augmentations to audio data using torchaudio.sox_effects
 
     :param audio: Audio

@@ -17,7 +15,6 @@ def apply_augmentation(audio: AudioT,
     import torch
     import torchaudio
 
-    from sonusai import SonusAIError
     from .augmentation import pad_audio_to_frame
     from .constants import SAMPLE_RATE
 

@@ -28,29 +25,29 @@ def apply_augmentation(audio: AudioT,
     # Normalize to globally set level (should this be a global config parameter,
     # or hard-coded into the script?)
     if augmentation.normalize is not None:
-        effects.append([
+        effects.append(["norm", str(augmentation.normalize)])
 
     if augmentation.gain is not None:
-        effects.append([
+        effects.append(["gain", str(augmentation.gain)])
 
     if augmentation.pitch is not None:
-        effects.append([
-        effects.append([
+        effects.append(["pitch", str(augmentation.pitch)])
+        effects.append(["rate", str(SAMPLE_RATE)])
 
     if augmentation.tempo is not None:
-        effects.append([
+        effects.append(["tempo", "-s", str(augmentation.tempo)])
 
     if augmentation.eq1 is not None:
-        effects.append([
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq1]])
 
     if augmentation.eq2 is not None:
-        effects.append([
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq2]])
 
     if augmentation.eq3 is not None:
-        effects.append([
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq3]])
 
     if augmentation.lpf is not None:
-        effects.append([
+        effects.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])
 
     if effects:
         if audio.ndim == 1:

@@ -58,11 +55,9 @@ def apply_augmentation(audio: AudioT,
         out = torch.tensor(audio)
 
         try:
-            out, _ = torchaudio.sox_effects.apply_effects_tensor(out,
-                                                                  sample_rate=SAMPLE_RATE,
-                                                                  effects=effects)
+            out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
         except Exception as e:
-            raise
+            raise RuntimeError(f"Error applying {augmentation}: {e}") from e
 
         audio_out = np.squeeze(np.array(out))
     else:

@@ -84,6 +79,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
     import torchaudio
 
     from sonusai.utils import linear_to_db
+
     from .constants import SAMPLE_RATE
 
     # Early exit if no ir or if all audio is zero

@@ -95,20 +91,20 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
 
     # Convert audio to IR sample rate
     audio_in = torch.reshape(torch.tensor(audio), (1, len(audio)))
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_in, sample_rate=SAMPLE_RATE, effects=[["rate", str(ir.sample_rate)]]
+    )
 
     # Apply IR and convert back to global sample rate
     rir = torch.reshape(torch.tensor(ir.data), (1, len(ir.data)))
     audio_out = torchaudio.functional.fftconvolve(audio_out, rir)
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_out, sample_rate=ir.sample_rate, effects=[["rate", str(SAMPLE_RATE)]]
+    )
 
     # Reset level to previous max value
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_out, sample_rate=SAMPLE_RATE, effects=[["norm", str(max_db)]]
+    )
 
-    return np.squeeze(np.array(audio_out[:, :len(audio)]))
+    return np.squeeze(np.array(audio_out[:, : len(audio)]))
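`apply_augmentation` above translates augmentation fields into sox effect commands (lists of strings) and runs them through `torchaudio.sox_effects.apply_effects_tensor`. A self-contained sketch of that pattern with made-up values, not sonusai defaults:

```python
import numpy as np
import torch
import torchaudio

SAMPLE_RATE = 16000  # illustrative; sonusai defines its own rate in sonusai.mixture.constants

# one second of quiet noise as stand-in audio
audio = np.random.default_rng(0).uniform(-0.1, 0.1, SAMPLE_RATE).astype(np.float32)

effects = [
    ["norm", "-3.5"],                    # like augmentation.normalize
    ["pitch", "300"],                    # like augmentation.pitch, followed by a rate effect
    ["rate", str(SAMPLE_RATE)],
    ["tempo", "-s", "0.9"],              # like augmentation.tempo
    ["lowpass", "-2", "4000", "0.707"],  # like augmentation.lpf
]

out, _ = torchaudio.sox_effects.apply_effects_tensor(
    torch.tensor(audio).reshape(1, -1), sample_rate=SAMPLE_RATE, effects=effects
)
print(out.shape)
```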