sonusai 0.15.8__py3-none-any.whl → 0.15.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,9 @@
  from typing import Any

+ from pyaaware import ForwardTransform
+ from pyaaware import InverseTransform
+
+ from sonusai.mixture import EnergyT
  from sonusai.mixture.datatypes import AudioF
  from sonusai.mixture.datatypes import AudioT
  from sonusai.mixture.datatypes import AudiosT
@@ -78,7 +82,7 @@ def get_feature_generator_info(fg_config: FeatureGeneratorConfig) -> FeatureGene
  decimation=fg.decimation,
  stride=fg.stride,
  step=fg.step,
- num_bands=fg.num_bands,
+ feature_parameters=fg.feature_parameters,
  ft_config=TransformConfig(N=fg.ftransform_N,
  R=fg.ftransform_R,
  bin_start=fg.bin_start,
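This release renames the feature generator's `num_bands` attribute to `feature_parameters` throughout. A minimal sketch of the new name, assuming the pyaaware `FeatureGenerator` constructor shown elsewhere in this diff; the feature mode string is only illustrative:

```python
from pyaaware import FeatureGenerator

# 'hum00ns1' is an illustrative feature mode; any valid mode string works here.
fg = FeatureGenerator(feature_mode='hum00ns1', num_classes=1, truth_mutex=False)

# 0.15.8 exposed this value as fg.num_bands; 0.15.9 renames it to fg.feature_parameters.
print(fg.stride, fg.step, fg.feature_parameters)
```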
@@ -327,15 +331,14 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
  import numpy as np
  from pyaaware import FeatureGenerator

- from .spectral_mask import apply_spectral_mask
  from .truth import truth_reduction

- mixture_f = get_mixture_f(mixdb=mixdb, mixture_audio=mixture_audio)
+ mixture_f = get_mixture_f(mixdb=mixdb, mixture=mixture, mixture_audio=mixture_audio)

  transform_frames = mixdb.mixture_transform_frames(mixture.samples)
  feature_frames = mixdb.mixture_feature_frames(mixture.samples)

- feature = np.empty((feature_frames, mixdb.fg_stride, mixdb.fg_num_bands), dtype=np.float32)
+ feature = np.empty((feature_frames, mixdb.fg_stride, mixdb.feature_parameters), dtype=np.float32)
  truth_f = np.empty((feature_frames, mixdb.num_classes), dtype=np.complex64)

  fg = FeatureGenerator(**asdict(mixdb.fg_config))
@@ -350,11 +353,6 @@ def get_ft(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT, trut
  truth_f[feature_frame] = fg.truth()
  feature_frame += 1

- if mixture.spectral_mask_id is not None:
- feature = apply_spectral_mask(feature=feature,
- spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
- seed=mixture.spectral_mask_seed)
-
  if np.isreal(truth_f).all():
  return feature, truth_f.real

@@ -444,14 +442,35 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: AudiosT)
  return np.sum(targets_ir, axis=0)


- def get_mixture_f(mixdb: MixtureDatabase, mixture_audio: AudioT) -> AudioF:
+ def get_mixture_f(mixdb: MixtureDatabase, mixture: Mixture, mixture_audio: AudioT) -> AudioF:
  """Get the mixture transform for the given mixture

  :param mixdb: Mixture database
+ :param mixture: Mixture record
  :param mixture_audio: Mixture audio data for the given mixid
  :return: Mixture transform data
  """
- return forward_transform(mixture_audio, mixdb.ft_config)
+ from .spectral_mask import apply_spectral_mask
+
+ mixture_f = forward_transform(mixture_audio, mixdb.ft_config)
+
+ if mixture.spectral_mask_id is not None:
+ mixture_f = apply_spectral_mask(audio_f=mixture_f,
+ spectral_mask=mixdb.spectral_mask(mixture.spectral_mask_id),
+ seed=mixture.spectral_mask_seed)
+
+ return mixture_f
+
+
+ def get_transform_from_audio(audio: AudioT, transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
+ """Apply forward transform to input audio data to generate transform data
+
+ :param audio: Time domain data [samples]
+ :param transform: ForwardTransform object
+ :return: Frequency domain data [frames, bins], Energy [frames]
+ """
+ f, e = transform.execute_all(audio)
+ return f.transpose(), e


  def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
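The new `get_transform_from_audio` helper (and its inverse counterpart below) replaces the `calculate_*` helpers that previously lived in `sonusai.mixture.audio`. A hedged usage sketch, assuming the helper is re-exported from `sonusai.mixture` the same way `get_audio_from_transform` is later in this diff; the transform parameters and ttype string are illustrative, not defaults:

```python
import numpy as np
from pyaaware import AawareForwardTransform

from sonusai.mixture import get_transform_from_audio

# One second of silence; the 16 kHz length and the transform geometry are illustrative.
audio = np.zeros(16000, dtype=np.float32)
fft = AawareForwardTransform(N=256, R=64, bin_start=1, bin_end=128, ttype='stft-olsa-hanns')

audio_f, energy_t = get_transform_from_audio(audio=audio, transform=fft)
print(audio_f.shape)  # [frames, bins], transposed from pyaaware's [bins, frames] layout
```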
@@ -465,17 +484,30 @@ def forward_transform(audio: AudioT, config: TransformConfig) -> AudioF:
  """
  from pyaaware import AawareForwardTransform

- from .audio import calculate_transform_from_audio
-
- audio_f, _ = calculate_transform_from_audio(audio=audio,
- transform=AawareForwardTransform(N=config.N,
- R=config.R,
- bin_start=config.bin_start,
- bin_end=config.bin_end,
- ttype=config.ttype))
+ audio_f, _ = get_transform_from_audio(audio=audio,
+ transform=AawareForwardTransform(N=config.N,
+ R=config.R,
+ bin_start=config.bin_start,
+ bin_end=config.bin_end,
+ ttype=config.ttype))
  return audio_f


+ def get_audio_from_transform(data: AudioF, transform: InverseTransform, trim: bool = True) -> tuple[AudioT, EnergyT]:
+ """Apply inverse transform to input transform data to generate audio data
+
+ :param data: Frequency domain data [frames, bins]
+ :param transform: InverseTransform object
+ :param trim: Removes starting samples so output waveform will be time-aligned with input waveform to the transform
+ :return: Time domain data [samples], Energy [frames]
+ """
+ t, e = transform.execute_all(data.transpose())
+ if trim:
+ t = t[transform.N - transform.R:]
+
+ return t, e
+
+
  def inverse_transform(transform: AudioF, config: TransformConfig, trim: bool = True) -> AudioT:
  """Transform frequency domain data into time domain using the inverse transform config from the feature

@@ -490,16 +522,14 @@ def inverse_transform(transform: AudioF, config: TransformConfig, trim: bool = T
  import numpy as np
  from pyaaware import AawareInverseTransform

- from .audio import calculate_audio_from_transform
-
- audio, _ = calculate_audio_from_transform(data=transform,
- transform=AawareInverseTransform(N=config.N,
- R=config.R,
- bin_start=config.bin_start,
- bin_end=config.bin_end,
- ttype=config.ttype,
- gain=np.float32(1)),
- trim=trim)
+ audio, _ = get_audio_from_transform(data=transform,
+ transform=AawareInverseTransform(N=config.N,
+ R=config.R,
+ bin_start=config.bin_start,
+ bin_end=config.bin_end,
+ ttype=config.ttype,
+ gain=np.float32(1)),
+ trim=trim)
  return audio

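Taken together, `forward_transform` and `inverse_transform` now route through the new helpers. A round-trip sketch, assuming both functions and `TransformConfig` are importable from `sonusai.mixture`; the transform geometry and ttype string are illustrative values, not package defaults:

```python
import numpy as np

from sonusai.mixture import TransformConfig, forward_transform, inverse_transform

# Illustrative transform configuration (fields match those used above).
config = TransformConfig(N=256, R=64, bin_start=1, bin_end=128, ttype='stft-olsa-hanns')

audio = np.random.default_rng(0).normal(size=16000).astype(np.float32)
audio_f = forward_transform(audio, config)       # [frames, bins]
audio_hat = inverse_transform(audio_f, config)   # trim=True keeps it time-aligned with the input
```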
@@ -534,7 +564,7 @@ def augmented_target_samples(target_files: TargetFiles,
  it = list(product(*[target_ids, target_augmentation_ids]))
  return sum([estimate_augmented_length_from_length(
  length=target_files[fi].samples,
- tempo=target_augmentations[ai].tempo,
+ tempo=float(target_augmentations[ai].tempo),
  frame_length=feature_step_samples) for fi, ai, in it])

@@ -1,7 +1,7 @@
  def log_duration_and_sizes(total_duration: float,
  num_classes: int,
  feature_step_samples: int,
- num_bands: int,
+ feature_parameters: int,
  stride: int,
  desc: str) -> None:
  from sonusai import logger
@@ -14,7 +14,7 @@ def log_duration_and_sizes(total_duration: float,
  total_samples = int(total_duration * SAMPLE_RATE)
  mixture_bytes = total_samples * SAMPLE_BYTES
  truth_t_bytes = total_samples * num_classes * FLOAT_BYTES
- feature_bytes = total_samples / feature_step_samples * stride * num_bands * FLOAT_BYTES
+ feature_bytes = total_samples / feature_step_samples * stride * feature_parameters * FLOAT_BYTES
  truth_f_bytes = total_samples / feature_step_samples * num_classes * FLOAT_BYTES

  logger.info('')
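The renamed `feature_parameters` flows into the size estimate exactly as `num_bands` did. A worked sketch of the byte math, assuming 16 kHz audio and 4-byte float storage; every numeric value below is illustrative:

```python
SAMPLE_RATE = 16000        # assumed sample rate
FLOAT_BYTES = 4            # assumed float32 storage

total_duration = 3600.0    # one hour of mixtures
feature_step_samples = 1024
stride = 6
feature_parameters = 140   # was "num_bands" in 0.15.8
num_classes = 3

total_samples = int(total_duration * SAMPLE_RATE)
feature_bytes = total_samples / feature_step_samples * stride * feature_parameters * FLOAT_BYTES
truth_f_bytes = total_samples / feature_step_samples * num_classes * FLOAT_BYTES
print(f'{feature_bytes / 2**30:.2f} GB feature, {truth_f_bytes / 2**20:.2f} MB truth_f')
```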
sonusai/mixture/mixdb.py CHANGED
@@ -248,8 +248,8 @@ class MixtureDatabase:
  return self.fg_info.step

  @cached_property
- def fg_num_bands(self) -> int:
- return self.fg_info.num_bands
+ def feature_parameters(self) -> int:
+ return self.fg_info.feature_parameters

  @cached_property
  def ft_config(self) -> TransformConfig:
@@ -809,11 +809,20 @@ class MixtureDatabase:
  :return: Mixture transform data
  """
  from .helpers import forward_transform
+ from .spectral_mask import apply_spectral_mask

  if force or mixture is None:
  mixture = self.mixture_mixture(m_id, targets, target, noise, force)

- return forward_transform(mixture, self.ft_config)
+ mixture_f = forward_transform(mixture, self.ft_config)
+
+ m = self.mixture(m_id)
+ if m.spectral_mask_id is not None:
+ mixture_f = apply_spectral_mask(audio_f=mixture_f,
+ spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
+ seed=m.spectral_mask_seed)
+
+ return mixture_f

  def mixture_truth_t(self,
  m_id: int,
@@ -938,7 +947,6 @@ class MixtureDatabase:
  import numpy as np
  from pyaaware import FeatureGenerator

- from .spectral_mask import apply_spectral_mask
  from .truth import truth_reduction

  if not force:
@@ -964,7 +972,7 @@ class MixtureDatabase:
  if truth_t is None:
  truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)

- feature = np.empty((feature_frames, self.fg_stride, self.fg_num_bands), dtype=np.float32)
+ feature = np.empty((feature_frames, self.fg_stride, self.feature_parameters), dtype=np.float32)
  truth_f = np.empty((feature_frames, self.num_classes), dtype=np.complex64)

  fg = FeatureGenerator(**asdict(self.fg_config))
@@ -979,11 +987,6 @@ class MixtureDatabase:
  truth_f[feature_frame] = fg.truth()
  feature_frame += 1

- if m.spectral_mask_id is not None:
- feature = apply_spectral_mask(feature=feature,
- spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
- seed=m.spectral_mask_seed)
-
  if np.isreal(truth_f).all():
  return feature, truth_f.real

@@ -1,23 +1,23 @@
- from sonusai.mixture.datatypes import Feature
+ from sonusai.mixture.datatypes import AudioF
  from sonusai.mixture.datatypes import SpectralMask


- def apply_spectral_mask(feature: Feature, spectral_mask: SpectralMask, seed: int = None) -> Feature:
+ def apply_spectral_mask(audio_f: AudioF, spectral_mask: SpectralMask, seed: int = None) -> AudioF:
  """Apply frequency and time masking

  Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

  Ref: https://arxiv.org/pdf/1904.08779.pdf

- f_width consecutive bands [f_start, f_start + f_width) are masked, where f_width is chosen from a uniform
- distribution from 0 to the f_max_width, and f_start is chosen from [0, bands - f_width).
+ f_width consecutive bins [f_start, f_start + f_width) are masked, where f_width is chosen from a uniform
+ distribution from 0 to the f_max_width, and f_start is chosen from [0, bins - f_width).

  t_width consecutive frames [t_start, t_start + t_width) are masked, where t_width is chosen from a uniform
  distribution from 0 to the t_max_width, and t_start is chosen from [0, frames - t_width).

  A time mask cannot be wider than t_max_percent times the number of frames.

- :param feature: Numpy array of feature data [frames, strides, bands]
+ :param audio_f: Numpy array of transform audio data [frames, bins]
  :param spectral_mask: Spectral mask parameters
  :param seed: Random number seed
  :return: Augmented feature
@@ -26,28 +26,28 @@ def apply_spectral_mask(feature: Feature, spectral_mask: SpectralMask, seed: int

  from sonusai import SonusAIError

- if feature.ndim != 3:
- raise SonusAIError('feature input must have three dimensions [frames, strides, bands]')
+ if audio_f.ndim != 2:
+ raise SonusAIError('feature input must have three dimensions [frames, bins]')

- frames, strides, bands = feature.shape
+ frames, bins = audio_f.shape

  f_max_width = spectral_mask.f_max_width
- if f_max_width not in range(0, bands + 1):
- f_max_width = bands
+ if f_max_width not in range(0, bins + 1):
+ f_max_width = bins

  rng = np.random.default_rng(seed)

  # apply f_num frequency masks to the feature
  for _ in range(spectral_mask.f_num):
  f_width = int(rng.uniform(0, f_max_width))
- f_start = rng.integers(0, bands - f_width, endpoint=True)
- feature[:, :, f_start:f_start + f_width] = 0
+ f_start = rng.integers(0, bins - f_width, endpoint=True)
+ audio_f[:, f_start:f_start + f_width] = 0

  # apply t_num time masks to the feature
  t_upper_bound = int(spectral_mask.t_max_percent / 100 * frames)
  for _ in range(spectral_mask.t_num):
  t_width = min(int(rng.uniform(0, spectral_mask.t_max_width)), t_upper_bound)
  t_start = rng.integers(0, frames - t_width, endpoint=True)
- feature[t_start:t_start + t_width, :, :] = 0
+ audio_f[t_start:t_start + t_width, :] = 0

- return feature
+ return audio_f
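Spectral masking now operates on two-dimensional transform data rather than the three-dimensional feature. A self-contained numpy sketch of the same SpecAugment-style logic on an illustrative [frames, bins] array (the widths, counts, and percentage below are hypothetical, not the package's SpectralMask dataclass):

```python
import numpy as np

rng = np.random.default_rng(seed=42)
audio_f = rng.normal(size=(100, 64)) + 1j * rng.normal(size=(100, 64))  # [frames, bins]
frames, bins = audio_f.shape

# one frequency mask: zero f_width consecutive bins starting at f_start
f_width = int(rng.uniform(0, 8))
f_start = rng.integers(0, bins - f_width, endpoint=True)
audio_f[:, f_start:f_start + f_width] = 0

# one time mask: zero t_width consecutive frames, capped at a percentage of the total
t_upper_bound = int(10 / 100 * frames)
t_width = min(int(rng.uniform(0, 20)), t_upper_bound)
t_start = rng.integers(0, frames - t_width, endpoint=True)
audio_f[t_start:t_start + t_width, :] = 0
```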
@@ -23,7 +23,7 @@ class Data:
  num_classes=config.num_classes,
  truth_mutex=config.mutex)

- self.num_bands = fg.num_bands
+ self.feature_parameters = fg.feature_parameters
  self.ttype = fg.ftransform_ttype
  self.frame_size = fg.ftransform_R

@@ -19,7 +19,7 @@ Output shape: [:, num_classes]

  from sonusai import SonusAIError

- if data.config.num_classes != data.num_bands:
+ if data.config.num_classes != data.feature_parameters:
  raise SonusAIError(f'Invalid num_classes for target_f truth: {data.config.num_classes}')

  target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
@@ -51,7 +51,7 @@ Output shape: [:, 2 * num_classes]
  """
  from sonusai import SonusAIError

- if data.config.num_classes != 2 * data.num_bands:
+ if data.config.num_classes != 2 * data.feature_parameters:
  raise SonusAIError(f'Invalid num_classes for target_mixture_f truth: {data.config.num_classes}')

  target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
sonusai/onnx_predict.py CHANGED
@@ -105,7 +105,7 @@ def main() -> None:
  logger.info('')
  logger.info(f'Run prediction on {input_name}')
  audio = read_audio(input_name)
- feature = get_feature_from_audio(audio=audio, feature=model_metadata.feature)
+ feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)

  predict = pad_and_predict(feature=feature,
  model_name=model_name,
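Callers of `get_feature_from_audio` now pass the feature mode under the `feature_mode` keyword instead of `feature`. A small hedged sketch of the updated call, with an illustrative WAV path and feature mode:

```python
from sonusai.mixture import get_feature_from_audio, read_audio

audio = read_audio('example.wav')                                        # path is illustrative
feature = get_feature_from_audio(audio=audio, feature_mode='hum00ns1')   # keyword was feature= in 0.15.8
```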
sonusai/plot.py CHANGED
@@ -314,7 +314,7 @@ def main() -> None:
  raise SonusAIError('Must specify MODEL when input is WAV')

  mixture_audio = read_audio(input_name)
- feature = get_feature_from_audio(audio=mixture_audio, feature=model.feature)
+ feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
  fg_config = FeatureGeneratorConfig(feature_mode=model.feature,
  num_classes=model.output_shape[-1],
  truth_mutex=False)
@@ -406,11 +406,11 @@ def main() -> None:
  title = f'{input_name}'
  pdf_name = f'{base_name}-plot.pdf'

- # Original size [frames, stride, num_bands]
+ # Original size [frames, stride, feature_parameters]
  # Decimate in the stride dimension
- # Reshape to get frames*decimated_stride, num_bands
+ # Reshape to get frames*decimated_stride, feature_parameters
  if feature.ndim != 3:
- raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, num_bands')
+ raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, feature_parameters')
  spectrogram = feature[:, -fg_step:, :]
  spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))

@@ -123,7 +123,7 @@ def _process(file: str) -> None:
  from pyaaware import AawareInverseTransform

  from sonusai import SonusAIError
- from sonusai.mixture import calculate_audio_from_transform
+ from sonusai.mixture import get_audio_from_transform
  from sonusai.utils import float_to_int16
  from sonusai.utils import unstack_complex
  from sonusai.utils import write_wav
@@ -135,13 +135,13 @@ def _process(file: str) -> None:
  raise SonusAIError(f'Error reading {file}: {e}')

  output_name = join(MP_GLOBAL.output_dir, splitext(basename(file))[0] + '.wav')
- audio, _ = calculate_audio_from_transform(data=predict,
- transform=AawareInverseTransform(N=MP_GLOBAL.N,
- R=MP_GLOBAL.R,
- bin_start=MP_GLOBAL.bin_start,
- bin_end=MP_GLOBAL.bin_end,
- ttype=MP_GLOBAL.ttype,
- gain=np.float32(1)))
+ audio, _ = get_audio_from_transform(data=predict,
+ transform=AawareInverseTransform(N=MP_GLOBAL.N,
+ R=MP_GLOBAL.R,
+ bin_start=MP_GLOBAL.bin_start,
+ bin_end=MP_GLOBAL.bin_end,
+ ttype=MP_GLOBAL.ttype,
+ gain=np.float32(1)))
  write_wav(name=output_name, audio=float_to_int16(audio))

sonusai/torchl_predict.py CHANGED
@@ -43,15 +43,38 @@ Outputs the following to tpredict-<TIMESTAMP> directory:
  torch_predict.log

  """
+ from os import makedirs
+ from os.path import basename
+ from os.path import isdir
  from os.path import join
+ from os.path import normpath
+ from os.path import splitext
  from typing import Any

  import h5py
  import torch
+ from docopt import docopt
+ from lightning.pytorch import Trainer
  from lightning.pytorch.callbacks import BasePredictionWriter
+ from pyaaware import FeatureGenerator
+ from pyaaware import TorchInverseTransform
+ from torchinfo import summary

+ import sonusai
+ from sonusai import create_file_handler
+ from sonusai import initial_log_messages
  from sonusai import logger
+ from sonusai import update_console_handler
+ from sonusai.data_generator import TorchFromMixtureDatabase
  from sonusai.mixture import Feature
+ from sonusai.mixture import MixtureDatabase
+ from sonusai.mixture import get_audio_from_feature
+ from sonusai.mixture import get_feature_from_audio
+ from sonusai.mixture import read_audio
+ from sonusai.utils import create_ts_name
+ from sonusai.utils import import_keras_model
+ from sonusai.utils import trim_docstring
+ from sonusai.utils import write_wav


  class CustomWriter(BasePredictionWriter):
@@ -61,7 +84,7 @@ class CustomWriter(BasePredictionWriter):

  def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
  # this will create N (num processes) files in `output_dir` each containing
- # the predictions of it's respective rank
+ # the predictions of its respective rank
  # torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

  # optionally, you can also save `batch_indices` to get the information about the data index
@@ -119,11 +142,6 @@ def power_uncompress(real, imag):


  def main() -> None:
- from docopt import docopt
-
- import sonusai
- from sonusai.utils import trim_docstring
-
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

  verbose = args['--verbose']
@@ -139,27 +157,6 @@ def main() -> None:
  wavdbg = args['--wavdbg'] # write .wav if true
  input_name = args['INPUT']

- from os import makedirs
- from os.path import basename
- from os.path import isdir
- from os.path import isfile
- from os.path import join
- from os.path import splitext
- from os.path import normpath
- import h5py
- # from sonusai.utils import float_to_int16
-
- from torchinfo import summary
- from sonusai import create_file_handler
- from sonusai import initial_log_messages
- from sonusai import update_console_handler
- from sonusai.mixture import MixtureDatabase
- from sonusai.mixture import get_feature_from_audio
- from sonusai.utils import import_keras_model
- from sonusai.mixture import read_audio
- from sonusai.utils import create_ts_name
- from sonusai.data_generator import TorchFromMixtureDatabase
-
  if batch_size is not None:
  batch_size = int(batch_size)
  if batch_size != 1:
@@ -222,6 +219,8 @@ def main() -> None:
  hparams['timesteps'] = timesteps

  logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
+ # hparams['cl_per_wght'] = 0.0
+ # hparams['feature'] = 'hum00ns1'
  try:
  model = litemodule.MyHyperModel(**hparams) # use hparams
  # litemodule.MyHyperModel.load_from_checkpoint(ckpt_name, **hparams)
@@ -303,33 +302,25 @@ def main() -> None:
  drop_last=False,
  num_workers=dlcpu)

- if wavdbg: # setup for wav write if enabled
- # Info needed to setup inverse transform
- from pyaaware import FeatureGenerator
- from pyaaware import TorchInverseTransform
- from torchaudio import save
- # from sonusai.utils import write_wav
-
- half = model.num_classes // 2
- fg = FeatureGenerator(feature_mode=model.hparams.feature,
- num_classes=model.num_classes,
- truth_mutex=model.truth_mutex)
- itf = TorchInverseTransform(N=fg.itransform_N,
- R=fg.itransform_R,
- bin_start=fg.bin_start,
- bin_end=fg.bin_end,
- ttype=fg.itransform_ttype)
-
- if mixdb.target_files[0].truth_settings[0].function == 'target_f' or \
- mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
- enable_truth_wav = True
- else:
- enable_truth_wav = False
-
+ # Info needed to set up inverse transform
+ half = model.num_classes // 2
+ fg = FeatureGenerator(feature_mode=model.hparams.feature,
+ num_classes=model.num_classes,
+ truth_mutex=model.truth_mutex)
+ itf = TorchInverseTransform(N=fg.itransform_N,
+ R=fg.itransform_R,
+ bin_start=fg.bin_start,
+ bin_end=fg.bin_end,
+ ttype=fg.itransform_ttype)
+
+ enable_truth_wav = False
+ enable_mix_wav = False
+ if wavdbg:
  if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
  enable_mix_wav = True
- else:
- enable_mix_wav = False
+ enable_truth_wav = True
+ elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
+ enable_truth_wav = True

  if reset:
  logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
@@ -351,26 +342,25 @@ def main() -> None:
  if wavdbg:
  owav_base = splitext(output_name)[0]
  tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
+ itf.reset()
  predwav, _ = itf.execute_all(tmp)
- # predwav, _ = calculate_audio_from_transform(tmp, itf, trim=True)
- save(owav_base + '.wav', predwav.permute([1, 0]), 16000, encoding='PCM_S', bits_per_sample=16)
+ # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
+ write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
  if enable_truth_wav:
  # Note this support truth type target_f and target_mixture_f
  tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
+ itf.reset()
  truthwav, _ = itf.execute_all(tmp)
- save(owav_base + '_truth.wav', truthwav.permute([1, 0]), 16000, encoding='PCM_S',
- bits_per_sample=16)
+ write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)

  if enable_mix_wav:
  tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
+ itf.reset()
  mixwav, _ = itf.execute_all(tmp.detach())
- save(owav_base + "_mix.wav", mixwav.permute([1, 0]), 16000, encoding='PCM_S',
- bits_per_sample=16)
- # write_wav(owav_base + "_truth.wav", truthwav, 16000)
+ write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)

  else:
  logger.info(f'Running {mixdb.num_mixtures} mixtures with model builtin prediction loop ...')
- from lightning.pytorch import Trainer
  pred_writer = CustomWriter(output_dir=output_dir, write_interval="epoch")
  trainer = Trainer(default_root_dir=output_dir,
  callbacks=[pred_writer],
@@ -489,32 +479,37 @@ def main() -> None:
  # logger.info(f'Saved results to {output_dir}')
  # return

- if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
- logger.exception(f'Do not know how to process input from {input_name}')
- raise SystemExit(1)
-
- logger.info(f'Run prediction on {len(input_name):,} WAV files')
+ logger.info(f'Run prediction on {len(input_name):,} audio files')
  for file in input_name:
- # Convert WAV to feature data
- audio = read_audio(file)
- feature = get_feature_from_audio(audio=audio, feature=model.feature)
+ # Convert audio to feature data
+ audio_in = read_audio(file)
+ feature = get_feature_from_audio(audio=audio_in, feature_mode=model.hparams.feature)

- # feature, predict = _pad_and_predict(hypermodel=hypermodel,
- # built_model=built_model,
- # feature=feature,
- # frames_per_batch=frames_per_batch)
+ with torch.no_grad():
+ predict = model(torch.tensor(feature))

- # clean = torch_istft_olsa_hanns(clean_spec_cmplx, mixdb.ift_config.N, mixdb.ift_config.R)
+ audio_out = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)

  output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
  with h5py.File(output_name, 'a') as f:
+ if 'audio_in' in f:
+ del f['audio_in']
+ f.create_dataset(name='audio_in', data=audio_in)
+
  if 'feature' in f:
  del f['feature']
  f.create_dataset(name='feature', data=feature)

- # if 'predict' in f:
- # del f['predict']
- # f.create_dataset(name='predict', data=predict)
+ if 'predict' in f:
+ del f['predict']
+ f.create_dataset(name='predict', data=predict)
+
+ if 'audio_out' in f:
+ del f['audio_out']
+ f.create_dataset(name='audio_out', data=audio_out)
+
+ output_name = join(output_dir, splitext(basename(file))[0] + '_predict.wav')
+ write_wav(output_name, audio_out, 16000)

  logger.info(f'Saved results to {output_dir}')
  del model
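The new prediction path in torchl_predict converts audio to a feature, runs the model, converts the prediction back to audio, and writes both an HDF5 record and a WAV. A condensed sketch of that flow under the same assumptions; `model` is assumed to be an already-loaded module with the interface used above, and the paths and feature mode are illustrative:

```python
import h5py
import torch

from sonusai.mixture import get_audio_from_feature, get_feature_from_audio, read_audio
from sonusai.utils import write_wav

audio_in = read_audio('input.wav')                                          # illustrative path
feature = get_feature_from_audio(audio=audio_in, feature_mode='hum00ns1')   # illustrative feature mode

with torch.no_grad():
    predict = model(torch.tensor(feature))  # model: assumed loaded, as in the code above

audio_out = get_audio_from_feature(feature=predict.numpy(), feature_mode='hum00ns1')

# Store every stage of the pipeline, replacing any existing datasets.
with h5py.File('output.h5', 'a') as f:
    for name, data in (('audio_in', audio_in), ('feature', feature),
                       ('predict', predict.numpy()), ('audio_out', audio_out)):
        if name in f:
            del f[name]
        f.create_dataset(name=name, data=data)

write_wav('output_predict.wav', audio_out, 16000)
```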
sonusai/utils/__init__.py CHANGED
@@ -2,10 +2,14 @@
  from .asl_p56 import asl_p56
  from .asr import ASRResult
  from .asr import calc_asr
+ from .audio_devices import get_default_input_device
+ from .audio_devices import get_input_device_index_by_name
+ from .audio_devices import get_input_devices
  from .braced_glob import braced_glob
  from .braced_glob import braced_iglob
  from .calculate_input_shape import calculate_input_shape
  from .convert_string_to_number import convert_string_to_number
+ from .create_timestamp import create_timestamp
  from .create_ts_name import create_ts_name
  from .dataclass_from_dict import dataclass_from_dict
  from .db import db_to_linear