sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,344 @@
|
|
1
|
+
from ..datatypes import AudioT
|
2
|
+
from ..datatypes import Effects
|
3
|
+
from .mixdb import MixtureDatabase
|
4
|
+
|
5
|
+
|
6
|
+
def get_effect_rules(location: str, config: dict, test: bool = False) -> dict[str, list[Effects]]:
    """Build the per-source effect rules described by a configuration.

    For every entry in ``config["sources"]``, each rule in its ``effects`` list has its
    IR references normalized by ``_parse_ir_rule`` and its ``expand`` directives fanned
    out into separate rules by ``_expand_effect_rules``. The resulting dicts are turned
    into ``Effects`` dataclasses and checked with ``validate_rules`` before returning.

    Note: ``_parse_ir_rule`` replaces the ``pre``/``post`` lists of each rule dict in
    place, so the rule dicts inside ``config`` are modified.

    :param location: Mixture database location
    :param config: Configuration dictionary with a ``sources`` mapping
    :param test: Passed through to ``MixtureDatabase``
    :return: Mapping of source category to its list of effect rules
    """
    from ..utils.dataclass_from_dict import list_dataclass_from_dict

    mixdb = MixtureDatabase(location, test)

    rules: dict[str, list[Effects]] = {}
    for category, source in config["sources"].items():
        expanded: list[dict] = []
        for raw_rule in source["effects"]:
            parsed = _parse_ir_rule(raw_rule, mixdb.num_ir_files)
            expanded = _expand_effect_rules(expanded, parsed)
        rules[category] = list_dataclass_from_dict(list[Effects], expanded)

    validate_rules(mixdb, rules)
    return rules
|
23
|
+
|
24
|
+
|
25
|
+
def _expand_effect_rules(expanded_rules: list[dict], rule: dict) -> list[dict]:
    """Recursively fan out ``expand`` directives in a rule's effect chains.

    The chains are scanned in the order ``pre`` then ``post``. The first entry for which
    ``expand`` yields more than one alternative is replaced by each alternative in turn.
    This produces one deep-copied rule per alternative, and each copy is expanded
    recursively. A rule with nothing left to expand is appended unchanged.

    :param expanded_rules: Accumulator; results are appended to it in place
    :param rule: Rule dict that may contain ``pre`` and/or ``post`` effect lists
    :return: The same ``expanded_rules`` list object
    """
    from copy import deepcopy

    from ..parse.expand import expand

    for chain_name in ("pre", "post"):
        if chain_name not in rule:
            continue

        chain = rule[chain_name]
        for position, entry in enumerate(chain):
            alternatives = expand(entry)
            if len(alternatives) <= 1:
                continue

            # Branch on this entry; the recursion handles any later expandable entries.
            for alternative in alternatives:
                variant = deepcopy(rule)
                variant_chain = deepcopy(chain)
                variant_chain[position] = alternative
                variant[chain_name] = variant_chain
                _expand_effect_rules(expanded_rules, variant)
            return expanded_rules

    expanded_rules.append(rule)
    return expanded_rules
|
46
|
+
|
47
|
+
|
48
|
+
def _parse_ir_rule(rule: dict, num_ir: int) -> dict:
    """Normalize the ``ir`` entries in a rule's ``pre``/``post`` effect chains.

    Non-``ir`` effects (including empty strings) pass through untouched. For ``ir``
    effects:

    - a bare ``ir`` with no parameter is dropped from the chain;
    - ``ir N`` (numeric parameter) is range-checked against ``num_ir`` and kept as-is;
    - parameters starting with ``rand``, ``choose`` or ``expand`` are kept verbatim,
      to be resolved later;
    - any other parameter is converted to IR indices with ``generic_ids_to_list``,
      range-checked, and rewritten as ``ir N`` (one index) or
      ``ir expand(a, b, ...)`` (several indices).

    The ``pre``/``post`` lists of ``rule`` are replaced in place and ``rule`` is returned.

    :param rule: Rule dict that may contain ``pre`` and/or ``post`` lists of effect strings
    :param num_ir: Number of IR files; valid indices are ``range(num_ir)``
    :return: The (mutated) rule dict
    :raises ValueError: If an IR index is out of range
    """
    from ..datatypes import EffectList
    from .helpers import generic_ids_to_list

    def _resolve_str(parameters: str) -> str:
        # Random/choice/expansion directives are resolved in later passes.
        if parameters.startswith(("rand", "choose", "expand")):
            return f"ir {parameters}"

        irs = generic_ids_to_list(num_ir, parameters)

        if not all(ro in range(num_ir) for ro in irs):
            raise ValueError(f"Invalid ir of {parameters}")

        if len(irs) == 1:
            return f"ir {irs[0]}"
        return f"ir expand({', '.join(map(str, irs))})"

    def _process(rules_in: EffectList) -> EffectList:
        rules_out: EffectList = []

        for rule_in in rules_in:
            parts = rule_in.split(maxsplit=1)

            # Empty/whitespace-only entries have no effect name; treat them like any
            # other non-ir effect instead of failing on parts[0].
            if not parts or parts[0] != "ir":
                rules_out.append(rule_in)
                continue

            if len(parts) == 1:
                # A bare "ir" with no index is dropped.
                continue

            parameters = parts[1]
            if parameters.isnumeric():
                ir = int(parameters)
                if ir not in range(num_ir):
                    raise ValueError(f"Invalid ir of {parameters}")
                rules_out.append(rule_in)
                continue

            # parameters is always a str here (it comes from str.split), so every
            # remaining form is handled by _resolve_str.
            rules_out.append(_resolve_str(parameters))

        return rules_out

    for key in ("pre", "post"):
        if key in rule:
            rule[key] = _process(rule[key])

    return rule
|
100
|
+
|
101
|
+
|
102
|
+
def apply_effects(
    mixdb: MixtureDatabase,
    audio: AudioT,
    effects: Effects,
    pre: bool = True,
    post: bool = True,
) -> AudioT:
    """Apply effects to audio data

    Each chain is walked left to right. Consecutive non-IR effects are batched and
    applied together with ``apply_sox_effects``. An ``ir <index>`` entry first flushes
    the pending batch, then applies the IR read from the database with ``apply_ir``.
    Effects left after the last IR are applied at the end of the chain. The input
    array is copied before processing.

    :param mixdb: Mixture database (IR file names, delays and cache setting)
    :param audio: Input audio
    :param effects: Effects
    :param pre: Apply pre-truth effects
    :param post: Apply post-truth effects
    :return: Output audio
    """
    from ..datatypes import EffectList
    from .ir_effects import apply_ir
    from .ir_effects import read_ir
    from .sox_effects import apply_sox_effects

    def _run_chain(signal: AudioT, chain: EffectList) -> AudioT:
        pending: EffectList = []
        for effect in chain:
            if not effect.startswith("ir "):
                pending.append(effect)
                continue

            # Flush the sox effects gathered before this IR, then start a new batch
            signal = apply_sox_effects(signal, pending)
            pending = []

            ir_index = int(effect.split()[1])
            ir_data = read_ir(
                name=mixdb.ir_file(ir_index),
                delay=mixdb.ir_delay(ir_index),
                use_cache=mixdb.use_cache,
            )
            signal = apply_ir(audio=signal, ir=ir_data)

        return apply_sox_effects(signal, pending)

    result = audio.copy()
    if pre:
        result = _run_chain(result, effects.pre)
    if post:
        result = _run_chain(result, effects.post)
    return result
|
157
|
+
|
158
|
+
|
159
|
+
def estimate_effected_length(
    samples: int,
    effects: Effects,
    frame_length: int = 1,
    pre: bool = True,
    post: bool = True,
) -> int:
    """Estimate the effected audio length

    Only ``speed`` and ``tempo`` effects are treated as length-changing. Each divides
    the running length by its factor and rounds half-up with ``int(x + 0.5)``. A
    ``speed`` factor with a trailing ``c`` is in cents and is converted to a ratio via
    ``2 ** (cents / 1200)``. All other effects leave the length unchanged. The final
    length is passed to ``get_padded_length`` together with ``frame_length``.

    :param samples: Original length in samples
    :param effects: Effects
    :param frame_length: Length will be a multiple of this
    :param pre: Apply pre-truth effects
    :param post: Apply post-truth effects
    :return: Estimated length in samples
    """
    import re

    from .pad_audio import get_padded_length

    # speed factor[c]
    speed_re = re.compile(r"^speed\s+(-?\d+(\.\d+)*)(c?)$")
    # tempo [-q] [-m|-s|-l] factor [segment [search [overlap]]]
    tempo_re = re.compile(r"^tempo\s+(-q\s+)?(((-m)|(-s)|(-l))\s+)?(\d+(\.\d+)*)")

    def _scaled(count: int, effect: str) -> int:
        speed = speed_re.match(effect)
        if speed:
            factor = float(speed.group(1))
            if speed.group(3):
                factor = float(2 ** (factor / 1200))
            return int(count / factor + 0.5)

        tempo = tempo_re.match(effect)
        if tempo:
            return int(count / float(tempo.group(7)) + 0.5)

        # other effects which do not affect length
        return count

    chains = []
    if pre:
        chains.append(effects.pre)
    if post:
        chains.append(effects.post)

    length = samples
    for chain in chains:
        for effect in chain:
            length = _scaled(length, effect)

    return get_padded_length(length, frame_length)
|
210
|
+
|
211
|
+
|
212
|
+
def effects_from_rules(mixdb: MixtureDatabase, rules: Effects) -> Effects:
    """Resolve the random directives in an effect rule into concrete effects.

    Works on a deep copy of ``rules``. In both the ``pre`` and ``post`` chains:

    - entries starting with ``ir choose`` are resolved to ``ir <index>`` by ``_choose_ir``;
    - any other entry containing ``rand`` is passed through ``rand``;
    - all remaining entries are left unchanged.

    :param mixdb: Mixture database (used to pick IRs)
    :param rules: Effect rules, possibly containing directives
    :return: New Effects with directives resolved
    :raises ValueError: From ``_choose_ir`` when an ``ir choose`` directive is invalid
    """
    from copy import deepcopy

    from ..parse.rand import rand

    def _resolve(entry: str) -> str:
        # "ir choose" takes precedence and must see the untouched directive text. Its
        # tag may itself contain "rand", and any rand() output would be discarded anyway.
        if entry.startswith("ir choose"):
            return _choose_ir(mixdb, entry)
        if "rand" in entry:
            return rand(entry)
        return entry

    effects = deepcopy(rules)
    for key in ("pre", "post"):
        setattr(effects, key, [_resolve(entry) for entry in getattr(effects, key)])

    return effects
|
228
|
+
|
229
|
+
|
230
|
+
def conform_audio_to_length(audio: "AudioT", length: int, loop: bool, start: int) -> "AudioT":
    """Conform audio to the given length

    With ``loop`` set, samples are read cyclically from ``start``; indices wrap modulo
    the audio length. Otherwise the ``length`` samples from ``start`` are taken, and any
    shortfall is zero-padded at the end. Empty audio has nothing to loop over, so it is
    zero-padded in either mode; numpy cannot "wrap" into an empty array.

    :param audio: Audio to conform
    :param length: Length of output
    :param loop: Loop samples or pad
    :param start: Starting sample offset
    :return: Conformed audio
    """
    import numpy as np

    if loop and len(audio) > 0:
        return np.take(audio, range(start, start + length), mode="wrap")

    # Non-loop mode (or nothing to loop over)
    audio_slice = audio[start : start + length]

    if len(audio_slice) >= length:
        # We have enough samples, truncate
        return audio_slice[:length]

    # We need padding (np.pad defaults to constant zeros)
    padding_needed = length - len(audio_slice)
    return np.pad(audio_slice, (0, padding_needed))
|
254
|
+
|
255
|
+
|
256
|
+
def validate_rules(mixdb: MixtureDatabase, rules: dict[str, list[Effects]]) -> None:
    """Check that every effect rule resolves to an acceptable effect chain.

    Each rule is resolved once with ``effects_from_rules``. Any post-truth effect whose
    text contains ``speed`` or ``tempo`` is rejected. All effects not starting with
    ``ir``, from both chains, are then checked together with ``validate_sox_effects``.
    Errors raised by those helpers propagate unchanged.

    :param mixdb: Mixture database
    :param rules: Mapping of category to its list of effect rules
    :raises ValueError: If a post-truth effect contains 'speed' or 'tempo'
    """
    from .sox_effects import validate_sox_effects

    for rule in (r for rule_list in rules.values() for r in rule_list):
        resolved = effects_from_rules(mixdb, rule)

        for effect in resolved.post:
            banned = next((word for word in ("speed", "tempo") if word in effect), None)
            if banned is not None:
                raise ValueError(f"'{banned}' effect is not allowed in post-truth effect chain.")

        sox_effects: list[str] = [
            effect for effect in (*resolved.pre, *resolved.post) if not effect.startswith("ir")
        ]
        validate_sox_effects(sox_effects)
|
277
|
+
|
278
|
+
|
279
|
+
def _choose_ir(mixdb: "MixtureDatabase", directive: str) -> str:
    """Evaluate the 'choose' directive for an ir.

    The directive is used to choose a random ir file from the database
    and may take one of the following forms:

        # choose a random ir file
        ir choose()

        # choose a random ir file between a and b, inclusive (a must be less than b)
        ir choose(a, b)

        # choose a random ir file with the specified tag
        ir choose(tag)

    :param mixdb: Mixture database
    :param directive: Directive to evaluate
    :return: Resolved value of the form ``ir <index>``
    :raises ValueError: If the directive is malformed, a range bound is invalid,
        or the tag is not in the database
    """
    import re
    from random import choice
    from random import randint

    choose_pattern = re.compile(r"^ir choose\(\)$")
    choose_range_pattern = re.compile(r"^ir choose\((\d+),\s*(\d+)\)$")
    choose_tag_pattern = re.compile(r"^ir choose\((\w+)\)$")

    if choose_pattern.match(directive):
        return f"ir {randint(0, mixdb.num_ir_files - 1)}"  # noqa: S311

    range_match = choose_range_pattern.match(directive)
    if range_match:
        lower_text = range_match.group(1)
        upper_text = range_match.group(2)
        lower = int(lower_text)
        upper = int(upper_text)
        if (
            lower < 0
            or lower >= mixdb.num_ir_files
            or upper < 0
            or upper >= mixdb.num_ir_files
            or lower >= upper
            # Reject non-canonical spellings such as leading zeros ("03")
            or str(lower) != lower_text
            or str(upper) != upper_text
        ):
            raise ValueError(
                f"Invalid rule: '{directive}'. Values must be integers between 0 and {mixdb.num_ir_files - 1}."
            )
        # Pick the index directly. Building an "ir N" string and eval()-ing it
        # can never succeed ("ir N" is not a Python expression).
        return f"ir {randint(lower, upper)}"  # noqa: S311

    tag_match = choose_tag_pattern.match(directive)
    if tag_match:
        tag = tag_match.group(1)
        if tag in mixdb.ir_tags:
            return f"ir {choice(mixdb.ir_file_ids_for_tag(tag))}"  # noqa: S311

        raise ValueError(f"Invalid rule: '{directive}'. Tag, '{tag}', not found in database.")

    raise ValueError(f"Invalid rule: '{directive}'.")
|
@@ -0,0 +1,78 @@
|
|
1
|
+
from ..datatypes import AudioT
|
2
|
+
from ..datatypes import Feature
|
3
|
+
|
4
|
+
|
5
|
+
def get_feature_from_audio(audio: AudioT, feature_mode: str) -> Feature:
    """Apply forward transform and generate feature data from audio data

    The forward transform is configured from the ``FeatureGenerator`` for
    ``feature_mode``. The transform frames are zero-padded up to a multiple of
    ``fg.step`` before ``fg.execute_all`` is run.

    :param audio: Time domain audio data [samples]
    :param feature_mode: Feature mode
    :return: Feature data [frames, strides, feature_parameters]
    """
    import numpy as np
    from pyaaware import FeatureGenerator

    from ..datatypes import TransformConfig
    from .helpers import forward_transform

    fg = FeatureGenerator(feature_mode=feature_mode)

    audio_f = forward_transform(
        audio=audio,
        config=TransformConfig(
            length=fg.ftransform_length,
            overlap=fg.ftransform_overlap,
            bin_start=fg.bin_start,
            bin_end=fg.bin_end,
            ttype=fg.ftransform_ttype,
        ),
    )

    # Need to pad transform data to account for SOV modes
    # audio_f [transform_frames, bins]
    # np.pad requires integer pad widths, so use integer arithmetic rather than
    # np.ceil (which returns a float). -n % step is the count of frames needed to
    # reach the next multiple of step.
    original_frames = audio_f.shape[0]
    pad_frames = int(-original_frames % fg.step)
    padded_audio_f = np.pad(audio_f, ((0, pad_frames), (0, 0)), mode="constant", constant_values=0)

    return fg.execute_all(padded_audio_f)[0]
|
39
|
+
|
40
|
+
|
41
|
+
def get_audio_from_feature(feature: Feature, feature_mode: str) -> AudioT:
    """Apply inverse transform to feature data to generate audio data

    The stride axis is removed and the result is converted with ``unstack_complex``.
    For feature modes starting with ``"h"`` it is also passed through
    ``power_uncompress``. It is then inverse-transformed using the ``FeatureGenerator``
    settings for ``feature_mode``, and the result is squeezed.

    :param feature: Feature data [frames, stride=1, feature_parameters]
    :param feature_mode: Feature mode
    :return: Audio data [samples]
    :raises ValueError: If ``feature`` is not 3-D or its stride axis is not 1
    """
    import numpy as np
    from pyaaware import FeatureGenerator

    from ..datatypes import TransformConfig
    from ..utils.compress import power_uncompress
    from ..utils.stacked_complex import unstack_complex
    from .helpers import inverse_transform

    if feature.ndim != 3:
        raise ValueError("feature must have 3 dimensions: [frames, stride=1, feature_parameters]")

    if feature.shape[1] != 1:
        raise ValueError("Strided feature data is not supported for audio extraction; stride must be 1.")

    fg = FeatureGenerator(feature_mode=feature_mode)

    # Drop only the (validated, size-1) stride axis. A bare squeeze() would also
    # collapse the frames axis for a single-frame feature, or a size-1 parameter axis,
    # and change the rank handed to unstack_complex.
    feature_complex = unstack_complex(feature[:, 0, :])
    if feature_mode.startswith("h"):
        feature_complex = power_uncompress(feature_complex)
    return np.squeeze(
        inverse_transform(
            transform=feature_complex,
            config=TransformConfig(
                length=fg.itransform_length,
                overlap=fg.itransform_overlap,
                bin_start=fg.bin_start,
                bin_end=fg.bin_end,
                ttype=fg.itransform_ttype,
            ),
        )
    )
|