sonusai 0.15.8__py3-none-any.whl → 0.15.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/audiofe.py +293 -0
- sonusai/calc_metric_spenh.py +3 -3
- sonusai/data_generator/dataset_from_mixdb.py +1 -1
- sonusai/data_generator/keras_from_mixdb.py +1 -1
- sonusai/genft.py +2 -1
- sonusai/genmixdb.py +4 -4
- sonusai/keras_predict.py +1 -1
- sonusai/lsdb.py +2 -2
- sonusai/main.py +2 -2
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +0 -34
- sonusai/mixture/datatypes.py +1 -1
- sonusai/mixture/feature.py +75 -21
- sonusai/mixture/helpers.py +60 -30
- sonusai/mixture/log_duration_and_sizes.py +2 -2
- sonusai/mixture/mixdb.py +13 -10
- sonusai/mixture/spectral_mask.py +14 -14
- sonusai/mixture/truth_functions/data.py +1 -1
- sonusai/mixture/truth_functions/target.py +2 -2
- sonusai/onnx_predict.py +1 -1
- sonusai/plot.py +4 -4
- sonusai/post_spenh_targetf.py +8 -8
- sonusai/torchl_predict.py +71 -76
- sonusai/utils/__init__.py +4 -0
- sonusai/utils/audio_devices.py +41 -0
- sonusai/utils/calculate_input_shape.py +3 -4
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/reshape.py +11 -11
- sonusai/utils/wave.py +12 -5
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/METADATA +8 -1
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/RECORD +33 -31
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/WHEEL +1 -1
- sonusai/evaluate.py +0 -245
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/helpers.py
CHANGED
@@ -1,5 +1,9 @@
 from typing import Any
 
+from pyaaware import ForwardTransform
+from pyaaware import InverseTransform
+
+from sonusai.mixture import EnergyT
 from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosT
@@ -78,7 +82,7 @@ def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGene
         decimation=fg.decimation,
         stride=fg.stride,
         step=fg.step,
-
+        feature_parameters=fg.feature_parameters,
         ft_config=TransformConfig(N=fg.ftransform_N,
                                   R=fg.ftransform_R,
                                   bin_start=fg.bin_start,
@@ -327,15 +331,14 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
     import numpy as np
     from pyaaware import FeatureGenerator
 
-    from .spectral_mask import apply_spectral_mask
     from .truth import truth_reduction
 
-    mixture_f = get_mixture_f(mixdb=mixdb, mixture_audio=mixture_audio)
+    mixture_f = get_mixture_f(mixdb=mixdb, mixture=mixture, mixture_audio=mixture_audio)
 
     transform_frames = mixdb.mixture_transform_frames(mixture.samples)
     feature_frames = mixdb.mixture_feature_frames(mixture.samples)
 
-    feature = np.empty((feature_frames, mixdb.fg_stride, mixdb.
+    feature = np.empty((feature_frames, mixdb.fg_stride, mixdb.feature_parameters), dtype=np.float32)
     truth_f = np.empty((feature_frames, mixdb.num_classes), dtype=np.complex64)
 
     fg = FeatureGenerator(**asdict(mixdb.fg_config))
@@ -350,11 +353,6 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
         truth_f[feature_frame] = fg.truth()
         feature_frame += 1
 
-    if mixture.spectral_mask_id is not None:
-        feature = apply_spectral_mask(feature=feature,
-                                      spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
-                                      seed=mixture.spectral_mask_seed)
-
     if np.isreal(truth_f).all():
         return feature, truth_f.real
 
@@ -444,14 +442,35 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: AudiosT)
     return np.sum(targets_ir, axis=0)
 
 
-def get_mixture_f(mixdb: MixtureDatabase, mixture_audio: AudioT) -> AudioF:
+def get_mixture_f(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT) -> AudioF:
     """Get the mixture transform for the given mixture
 
     :param mixdb: Mixture database
+    :param mixture: Mixture record
     :param mixture_audio: Mixture audio data for the given mixid
     :return: Mixture transform data
     """
-
+    from .spectral_mask import apply_spectral_mask
+
+    mixture_f = forward_transform(mixture_audio, mixdb.ft_config)
+
+    if mixture.spectral_mask_id is not None:
+        mixture_f = apply_spectral_mask(audio_f=mixture_f,
+                                        spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
+                                        seed=mixture.spectral_mask_seed)
+
+    return mixture_f
+
+
+def get_transform_from_audio(audio: AudioT, transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
+    """Apply forward transform to input audio data to generate transform data
+
+    :param audio: Time domain data [samples]
+    :param transform: ForwardTransform object
+    :return: Frequency domain data [frames, bins], Energy [frames]
+    """
+    f, e = transform.execute_all(audio)
+    return f.transpose(), e
 
 
 def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
@@ -465,17 +484,30 @@ def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
     """
     from pyaaware import AawareForwardTransform
 
-
-
-
-
-
-
-                                           bin_end=config.bin_end,
-                                           ttype=config.ttype))
+    audio_f, _ = get_transform_from_audio(audio=audio,
+                                          transform=AawareForwardTransform(N=config.N,
+                                                                           R=config.R,
+                                                                           bin_start=config.bin_start,
+                                                                           bin_end=config.bin_end,
+                                                                           ttype=config.ttype))
     return audio_f
 
 
+def get_audio_from_transform(data: AudioF, transform: InverseTransform, trim: bool = True) -> tuple[AudioT, EnergyT]:
+    """Apply inverse transform to input transform data to generate audio data
+
+    :param data: Frequency domain data [frames, bins]
+    :param transform: InverseTransform object
+    :param trim: Removes starting samples so output waveform will be time-aligned with input waveform to the transform
+    :return: Time domain data [samples], Energy [frames]
+    """
+    t, e = transform.execute_all(data.transpose())
+    if trim:
+        t = t[transform.N - transform.R:]
+
+    return t, e
+
+
 def inverse_transform(transform: AudioF, config: TransformConfig, trim: bool = True) -> AudioT:
     """Transform frequency domain data into time domain using the inverse transform config from the feature
 
@@ -490,16 +522,14 @@ def inverse_transform(transform: AudioF, config: TransformConfig, trim: bool = T
     import numpy as np
     from pyaaware import AawareInverseTransform
 
-
-
-
-
-
-
-
-
-                                         gain=np.float32(1)),
-                                         trim=trim)
+    audio, _ = get_audio_from_transform(data=transform,
+                                        transform=AawareInverseTransform(N=config.N,
+                                                                         R=config.R,
+                                                                         bin_start=config.bin_start,
+                                                                         bin_end=config.bin_end,
+                                                                         ttype=config.ttype,
+                                                                         gain=np.float32(1)),
+                                        trim=trim)
     return audio
 
 
@@ -534,7 +564,7 @@ def augmented_target_samples(target_files: TargetFiles,
     it = list(product(*[target_ids, target_augmentation_ids]))
     return sum([estimate_augmented_length_from_length(
         length=target_files[fi].samples,
-        tempo=target_augmentations[ai].tempo,
+        tempo=float(target_augmentations[ai].tempo),
        frame_length=feature_step_samples) for fi, ai, in it])
 
 
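Illustration (not part of the diff): the hunks above route forward_transform and inverse_transform through the new get_transform_from_audio / get_audio_from_transform helpers. A minimal round-trip sketch, assuming the helpers are importable from sonusai.mixture.helpers and that a MixtureDatabase at an illustrative path exposes ft_config and it_config transform configs (it_config is not shown in this diff):

    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import read_audio
    from sonusai.mixture.helpers import forward_transform
    from sonusai.mixture.helpers import inverse_transform

    mixdb = MixtureDatabase('path/to/mixdb')                # illustrative location
    audio = read_audio('example.wav')                       # time domain [samples]
    audio_f = forward_transform(audio, mixdb.ft_config)     # frequency domain [frames, bins]
    audio_r = inverse_transform(audio_f, mixdb.it_config)   # back to time domain [samples]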
sonusai/mixture/log_duration_and_sizes.py
CHANGED
@@ -1,7 +1,7 @@
 def log_duration_and_sizes(total_duration: float,
                            num_classes: int,
                            feature_step_samples: int,
-
+                           feature_parameters: int,
                            stride: int,
                            desc: str) -> None:
     from sonusai import logger
@@ -14,7 +14,7 @@ def log_duration_and_sizes(total_duration: float,
     total_samples = int(total_duration * SAMPLE_RATE)
     mixture_bytes = total_samples * SAMPLE_BYTES
     truth_t_bytes = total_samples * num_classes * FLOAT_BYTES
-    feature_bytes = total_samples / feature_step_samples * stride *
+    feature_bytes = total_samples / feature_step_samples * stride * feature_parameters * FLOAT_BYTES
     truth_f_bytes = total_samples / feature_step_samples * num_classes * FLOAT_BYTES
 
     logger.info('')
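Illustration (not part of the diff): the corrected feature_bytes line is simple arithmetic. A rough worked example, assuming 16 kHz samples, 4-byte floats, and illustrative feature dimensions (none of these values come from the diff):

    SAMPLE_RATE = 16000                          # assumed SonusAI sample rate
    FLOAT_BYTES = 4                              # assumed float32 size
    total_samples = 3600 * SAMPLE_RATE           # one hour of audio
    feature_step_samples = 1024                  # illustrative
    stride = 6                                   # illustrative
    feature_parameters = 64                      # illustrative
    feature_bytes = total_samples / feature_step_samples * stride * feature_parameters * FLOAT_BYTES
    # ~86.4 MB for this configuration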
sonusai/mixture/mixdb.py
CHANGED
@@ -248,8 +248,8 @@ class MixtureDatabase:
         return self.fg_info.step
 
     @cached_property
-    def
-        return self.fg_info.
+    def feature_parameters(self) -> int:
+        return self.fg_info.feature_parameters
 
     @cached_property
     def ft_config(self) -> TransformConfig:
@@ -809,11 +809,20 @@ class MixtureDatabase:
         :return: Mixture transform data
         """
         from .helpers import forward_transform
+        from .spectral_mask import apply_spectral_mask
 
         if force or mixture is None:
             mixture = self.mixture_mixture(m_id, targets, target, noise, force)
 
-
+        mixture_f = forward_transform(mixture, self.ft_config)
+
+        m = self.mixture(m_id)
+        if m.spectral_mask_id is not None:
+            mixture_f = apply_spectral_mask(audio_f=mixture_f,
+                                            spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
+                                            seed=m.spectral_mask_seed)
+
+        return mixture_f
 
     def mixture_truth_t(self,
                         m_id: int,
@@ -938,7 +947,6 @@ class MixtureDatabase:
         import numpy as np
         from pyaaware import FeatureGenerator
 
-        from .spectral_mask import apply_spectral_mask
         from .truth import truth_reduction
 
         if not force:
@@ -964,7 +972,7 @@ class MixtureDatabase:
         if truth_t is None:
             truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
 
-        feature = np.empty((feature_frames, self.fg_stride, self.
+        feature = np.empty((feature_frames, self.fg_stride, self.feature_parameters), dtype=np.float32)
         truth_f = np.empty((feature_frames, self.num_classes), dtype=np.complex64)
 
         fg = FeatureGenerator(**asdict(self.fg_config))
@@ -979,11 +987,6 @@ class MixtureDatabase:
             truth_f[feature_frame] = fg.truth()
             feature_frame += 1
 
-        if m.spectral_mask_id is not None:
-            feature = apply_spectral_mask(feature=feature,
-                                          spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
-                                          seed=m.spectral_mask_seed)
-
         if np.isreal(truth_f).all():
             return feature, truth_f.real
 
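Illustration (not part of the diff): the mixdb.py and helpers.py hunks move spectral masking from the decimated feature to the complex mixture transform. Condensed, the new flow is roughly the following, using only calls that appear above (mixture_audio, mask_id, and seed are illustrative variables):

    from sonusai.mixture.helpers import forward_transform
    from sonusai.mixture.spectral_mask import apply_spectral_mask

    mixture_f = forward_transform(mixture_audio, mixdb.ft_config)   # complex [frames, bins]
    if mask_id is not None:
        mixture_f = apply_spectral_mask(audio_f=mixture_f,
                                        spectral_mask=mixdb.spectral_mask(mask_id),
                                        seed=seed)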
sonusai/mixture/spectral_mask.py
CHANGED
@@ -1,23 +1,23 @@
-from sonusai.mixture.datatypes import
+from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import SpectralMask
 
 
-def apply_spectral_mask(
+def apply_spectral_mask(audio_f: AudioF, spectral_mask: SpectralMask, seed: int = None) -> AudioF:
     """Apply frequency and time masking
 
     Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
 
     Ref: https://arxiv.org/pdf/1904.08779.pdf
 
-    f_width consecutive
-    distribution from 0 to the f_max_width, and f_start is chosen from [0,
+    f_width consecutive bins [f_start, f_start + f_width) are masked, where f_width is chosen from a uniform
+    distribution from 0 to the f_max_width, and f_start is chosen from [0, bins - f_width).
 
     t_width consecutive frames [t_start, t_start + t_width) are masked, where t_width is chosen from a uniform
     distribution from 0 to the t_max_width, and t_start is chosen from [0, frames - t_width).
 
     A time mask cannot be wider than t_max_percent times the number of frames.
 
-    :param
+    :param audio_f: Numpy array of transform audio data [frames, bins]
     :param spectral_mask: Spectral mask parameters
     :param seed: Random number seed
     :return: Augmented feature
@@ -26,28 +26,28 @@ def apply_spectral_mask(feature: Feature, spectral_mask: SpectralMask, seed: int
 
     from sonusai import SonusAIError
 
-    if
-        raise SonusAIError('feature input must have three dimensions [frames,
+    if audio_f.ndim != 2:
+        raise SonusAIError('feature input must have three dimensions [frames, bins]')
 
-    frames,
+    frames, bins = audio_f.shape
 
     f_max_width = spectral_mask.f_max_width
-    if f_max_width not in range(0,
-        f_max_width =
+    if f_max_width not in range(0, bins + 1):
+        f_max_width = bins
 
     rng = np.random.default_rng(seed)
 
     # apply f_num frequency masks to the feature
     for _ in range(spectral_mask.f_num):
         f_width = int(rng.uniform(0, f_max_width))
-        f_start = rng.integers(0,
-
+        f_start = rng.integers(0, bins - f_width, endpoint=True)
+        audio_f[:, f_start:f_start + f_width] = 0
 
     # apply t_num time masks to the feature
     t_upper_bound = int(spectral_mask.t_max_percent / 100 * frames)
     for _ in range(spectral_mask.t_num):
         t_width = min(int(rng.uniform(0, spectral_mask.t_max_width)), t_upper_bound)
         t_start = rng.integers(0, frames - t_width, endpoint=True)
-
+        audio_f[t_start:t_start + t_width, :] = 0
 
-    return
+    return audio_f
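Illustration (not part of the diff): the docstring above describes SpecAugment-style masking, now applied to [frames, bins] transform data. A standalone NumPy sketch of the same idea, independent of the SpectralMask dataclass (mask widths and counts here are arbitrary):

    import numpy as np

    rng = np.random.default_rng(seed=0)
    audio_f = rng.normal(size=(100, 64)) + 1j * rng.normal(size=(100, 64))  # [frames, bins]
    frames, bins = audio_f.shape

    # one frequency mask: zero a random band of consecutive bins
    f_width = int(rng.uniform(0, 16))
    f_start = rng.integers(0, bins - f_width, endpoint=True)
    audio_f[:, f_start:f_start + f_width] = 0

    # one time mask: zero a random run of consecutive frames
    t_width = int(rng.uniform(0, 10))
    t_start = rng.integers(0, frames - t_width, endpoint=True)
    audio_f[t_start:t_start + t_width, :] = 0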
sonusai/mixture/truth_functions/target.py
CHANGED
@@ -19,7 +19,7 @@ Output shape: [:, num_classes]
 
     from sonusai import SonusAIError
 
-    if data.config.num_classes != data.
+    if data.config.num_classes != data.feature_parameters:
         raise SonusAIError(f'Invalid num_classes for target_f truth: {data.config.num_classes}')
 
     target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
@@ -51,7 +51,7 @@ Output shape: [:, 2 * num_classes]
     """
     from sonusai import SonusAIError
 
-    if data.config.num_classes != 2 * data.
+    if data.config.num_classes != 2 * data.feature_parameters:
         raise SonusAIError(f'Invalid num_classes for target_mixture_f truth: {data.config.num_classes}')
 
     target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
sonusai/onnx_predict.py
CHANGED
@@ -105,7 +105,7 @@ def main() -> None:
         logger.info('')
         logger.info(f'Run prediction on {input_name}')
         audio = read_audio(input_name)
-        feature = get_feature_from_audio(audio=audio,
+        feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
 
         predict = pad_and_predict(feature=feature,
                                   model_name=model_name,
sonusai/plot.py
CHANGED
@@ -314,7 +314,7 @@ def main() -> None:
            raise SonusAIError('Must specify MODEL when input is WAV')
 
        mixture_audio = read_audio(input_name)
-       feature = get_feature_from_audio(audio=mixture_audio,
+       feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
        fg_config = FeatureGeneratorConfig(feature_mode=model.feature,
                                           num_classes=model.output_shape[-1],
                                           truth_mutex=False)
@@ -406,11 +406,11 @@ def main() -> None:
    title = f'{input_name}'
    pdf_name = f'{base_name}-plot.pdf'
 
-   # Original size [frames, stride,
+   # Original size [frames, stride, feature_parameters]
    # Decimate in the stride dimension
-   # Reshape to get frames*decimated_stride,
+   # Reshape to get frames*decimated_stride, feature_parameters
    if feature.ndim != 3:
-       raise SonusAIError(f'feature does not have 3 dimensions: frames, stride,
+       raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, feature_parameters')
    spectrogram = feature[:, -fg_step:, :]
    spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))
 
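Illustration (not part of the diff): the comment block above compresses a decimate-then-reshape step. In terms of shapes, with illustrative dimensions:

    import numpy as np

    frames, stride, feature_parameters = 10, 6, 64     # illustrative
    fg_step = 2                                        # illustrative
    feature = np.zeros((frames, stride, feature_parameters), dtype=np.float32)

    spectrogram = feature[:, -fg_step:, :]             # keep the last fg_step strides -> (10, 2, 64)
    spectrogram = np.reshape(spectrogram,
                             (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))
    # -> (20, 64): frames * decimated_stride rows of feature_parameters columns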
sonusai/post_spenh_targetf.py
CHANGED
@@ -123,7 +123,7 @@ def _process(file: str) -> None:
    from pyaaware import AawareInverseTransform
 
    from sonusai import SonusAIError
-   from sonusai.mixture import
+   from sonusai.mixture import get_audio_from_transform
    from sonusai.utils import float_to_int16
    from sonusai.utils import unstack_complex
    from sonusai.utils import write_wav
@@ -135,13 +135,13 @@ def _process(file: str) -> None:
        raise SonusAIError(f'Error reading {file}: {e}')
 
    output_name = join(MP_GLOBAL.output_dir, splitext(basename(file))[0] + '.wav')
-   audio, _ =
-
-
-
-
-
-
+   audio, _ = get_audio_from_transform(data=predict,
+                                       transform=AawareInverseTransform(N=MP_GLOBAL.N,
+                                                                        R=MP_GLOBAL.R,
+                                                                        bin_start=MP_GLOBAL.bin_start,
+                                                                        bin_end=MP_GLOBAL.bin_end,
+                                                                        ttype=MP_GLOBAL.ttype,
+                                                                        gain=np.float32(1)))
    write_wav(name=output_name, audio=float_to_int16(audio))
 
 
sonusai/torchl_predict.py
CHANGED
@@ -43,15 +43,38 @@ Outputs the following to tpredict-<TIMESTAMP> directory:
    torch_predict.log
 
 """
+from os import makedirs
+from os.path import basename
+from os.path import isdir
 from os.path import join
+from os.path import normpath
+from os.path import splitext
 from typing import Any
 
 import h5py
 import torch
+from docopt import docopt
+from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
+from pyaaware import FeatureGenerator
+from pyaaware import TorchInverseTransform
+from torchinfo import summary
 
+import sonusai
+from sonusai import create_file_handler
+from sonusai import initial_log_messages
 from sonusai import logger
+from sonusai import update_console_handler
+from sonusai.data_generator import TorchFromMixtureDatabase
 from sonusai.mixture import Feature
+from sonusai.mixture import MixtureDatabase
+from sonusai.mixture import get_audio_from_feature
+from sonusai.mixture import get_feature_from_audio
+from sonusai.mixture import read_audio
+from sonusai.utils import create_ts_name
+from sonusai.utils import import_keras_model
+from sonusai.utils import trim_docstring
+from sonusai.utils import write_wav
 
 
 class CustomWriter(BasePredictionWriter):
@@ -61,7 +84,7 @@ class CustomWriter(BasePredictionWriter):
 
    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # this will create N (num processes) files in `output_dir` each containing
-       # the predictions of
+       # the predictions of its respective rank
        # torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
 
        # optionally, you can also save `batch_indices` to get the information about the data index
@@ -119,11 +142,6 @@ def power_uncompress(real, imag):
 
 
 def main() -> None:
-   from docopt import docopt
-
-   import sonusai
-   from sonusai.utils import trim_docstring
-
    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
 
    verbose = args['--verbose']
@@ -139,27 +157,6 @@ def main() -> None:
    wavdbg = args['--wavdbg'] # write .wav if true
    input_name = args['INPUT']
 
-   from os import makedirs
-   from os.path import basename
-   from os.path import isdir
-   from os.path import isfile
-   from os.path import join
-   from os.path import splitext
-   from os.path import normpath
-   import h5py
-   # from sonusai.utils import float_to_int16
-
-   from torchinfo import summary
-   from sonusai import create_file_handler
-   from sonusai import initial_log_messages
-   from sonusai import update_console_handler
-   from sonusai.mixture import MixtureDatabase
-   from sonusai.mixture import get_feature_from_audio
-   from sonusai.utils import import_keras_model
-   from sonusai.mixture import read_audio
-   from sonusai.utils import create_ts_name
-   from sonusai.data_generator import TorchFromMixtureDatabase
-
    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size != 1:
@@ -222,6 +219,8 @@ def main() -> None:
    hparams['timesteps'] = timesteps
 
    logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
+   # hparams['cl_per_wght'] = 0.0
+   # hparams['feature'] = 'hum00ns1'
    try:
        model = litemodule.MyHyperModel(**hparams) # use hparams
        # litemodule.MyHyperModel.load_from_checkpoint(ckpt_name, **hparams)
@@ -303,33 +302,25 @@ def main() -> None:
                                         drop_last=False,
                                         num_workers=dlcpu)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                               bin_end=fg.bin_end,
-                               ttype=fg.itransform_ttype)
-
-   if mixdb.target_files[0].truth_settings[0].function == 'target_f' or \
-           mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
-       enable_truth_wav = True
-   else:
-       enable_truth_wav = False
-
+   # Info needed to set up inverse transform
+   half = model.num_classes // 2
+   fg = FeatureGenerator(feature_mode=model.hparams.feature,
+                         num_classes=model.num_classes,
+                         truth_mutex=model.truth_mutex)
+   itf = TorchInverseTransform(N=fg.itransform_N,
+                               R=fg.itransform_R,
+                               bin_start=fg.bin_start,
+                               bin_end=fg.bin_end,
+                               ttype=fg.itransform_ttype)
+
+   enable_truth_wav = False
+   enable_mix_wav = False
+   if wavdbg:
       if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
           enable_mix_wav = True
-
-
+          enable_truth_wav = True
+      elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
+          enable_truth_wav = True
 
   if reset:
       logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
@@ -351,26 +342,25 @@ def main() -> None:
           if wavdbg:
               owav_base = splitext(output_name)[0]
               tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
+              itf.reset()
               predwav, _ = itf.execute_all(tmp)
-              # predwav, _ = calculate_audio_from_transform(tmp, itf, trim=True)
-
+              # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
+              write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
              if enable_truth_wav:
                  # Note this support truth type target_f and target_mixture_f
                  tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
+                 itf.reset()
                  truthwav, _ = itf.execute_all(tmp)
-
-                 bits_per_sample=16)
+                 write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
 
              if enable_mix_wav:
                  tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
+                 itf.reset()
                  mixwav, _ = itf.execute_all(tmp.detach())
-
-                 bits_per_sample=16)
-                 # write_wav(owav_base + "_truth.wav", truthwav, 16000)
+                 write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
 
  else:
      logger.info(f'Running {mixdb.num_mixtures} mixtures with model builtin prediction loop ...')
-     from lightning.pytorch import Trainer
      pred_writer = CustomWriter(output_dir=output_dir, write_interval="epoch")
      trainer = Trainer(default_root_dir=output_dir,
                        callbacks=[pred_writer],
@@ -489,32 +479,37 @@ def main() -> None:
      # logger.info(f'Saved results to {output_dir}')
      # return
 
-
-     logger.exception(f'Do not know how to process input from {input_name}')
-     raise SystemExit(1)
-
-  logger.info(f'Run prediction on {len(input_name):,} WAV files')
+  logger.info(f'Run prediction on {len(input_name):,} audio files')
  for file in input_name:
-     # Convert
-
-     feature = get_feature_from_audio(audio=
+     # Convert audio to feature data
+     audio_in = read_audio(file)
+     feature = get_feature_from_audio(audio=audio_in, feature_mode=model.hparams.feature)
 
-
-
-     # feature=feature,
-     # frames_per_batch=frames_per_batch)
+     with torch.no_grad():
+         predict = model(torch.tensor(feature))
 
-
+     audio_out = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)
 
      output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
      with h5py.File(output_name, 'a') as f:
+         if 'audio_in' in f:
+             del f['audio_in']
+         f.create_dataset(name='audio_in', data=audio_in)
+
          if 'feature' in f:
             del f['feature']
         f.create_dataset(name='feature', data=feature)
 
-
-
-
+        if 'predict' in f:
+            del f['predict']
+        f.create_dataset(name='predict', data=predict)
+
+        if 'audio_out' in f:
+            del f['audio_out']
+        f.create_dataset(name='audio_out', data=audio_out)
+
+    output_name = join(output_dir, splitext(basename(file))[0] + '_predict.wav')
+    write_wav(output_name, audio_out, 16000)
 
 logger.info(f'Saved results to {output_dir}')
 del model
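Illustration (not part of the diff): for a single audio file, the new code path in the last hunk reduces to the sequence below; model is assumed to be an already-constructed Lightning module with a feature hparam, and the file names are illustrative:

    import torch

    from sonusai.mixture import get_audio_from_feature
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import write_wav

    audio_in = read_audio('input.wav')
    feature = get_feature_from_audio(audio=audio_in, feature_mode=model.hparams.feature)

    with torch.no_grad():
        predict = model(torch.tensor(feature))

    audio_out = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)
    write_wav('input_predict.wav', audio_out, 16000)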
sonusai/utils/__init__.py
CHANGED
@@ -2,10 +2,14 @@
 from .asl_p56 import asl_p56
 from .asr import ASRResult
 from .asr import calc_asr
+from .audio_devices import get_default_input_device
+from .audio_devices import get_input_device_index_by_name
+from .audio_devices import get_input_devices
 from .braced_glob import braced_glob
 from .braced_glob import braced_iglob
 from .calculate_input_shape import calculate_input_shape
 from .convert_string_to_number import convert_string_to_number
+from .create_timestamp import create_timestamp
 from .create_ts_name import create_ts_name
 from .dataclass_from_dict import dataclass_from_dict
 from .db import db_to_linear