sonusai 0.15.8__py3-none-any.whl → 0.15.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/audiofe.py ADDED
@@ -0,0 +1,293 @@
+ """sonusai audiofe
+
+ usage: audiofe [-hvds] [--version] [-i INPUT] [-l LENGTH] [-m MODEL] [-k CKPT] [-a ASR] [-w WMODEL]
+
+ options:
+     -h, --help
+     -v, --verbose                Be verbose.
+     -d, --debug                  Write debug data to H5 file.
+     -s, --show                   Show a list of available audio inputs.
+     -i INPUT, --input INPUT      Input audio.
+     -l LENGTH, --length LENGTH   Length of audio in seconds. [default: -1].
+     -m MODEL, --model MODEL      PL model .py file path.
+     -k CKPT, --checkpoint CKPT   PL checkpoint file with weights.
+     -a ASR, --asr ASR            ASR method to use.
+     -w WMODEL, --whisper WMODEL  Whisper model used in aixplain_whisper and whisper methods. [default: tiny].
+
+ Aaware SonusAI Audio Front End.
+
+ Capture LENGTH seconds of audio from INPUT. If LENGTH is < 0, then capture until a key is pressed. If INPUT is a valid
+ audio file name, then use the audio data from the specified file. In this case, if LENGTH is < 0, process the entire
+ file; otherwise, process min(length(INPUT), LENGTH) seconds of audio from INPUT. Audio is saved to
+ audiofe_capture_<TIMESTAMP>.wav.
+
+ If a model is specified, run prediction on the audio data with this model. Then compute the inverse transform of the
+ prediction result and save to audiofe_predict_<TIMESTAMP>.wav.
+
+ If an ASR is specified, run ASR on the captured audio and print the results. In addition, if a model was also
+ specified, run ASR on the predict audio and print the results.
+
+ If the debug option is enabled, write capture audio, feature, reconstruct audio, predict, and predict audio to
+ audiofe_<TIMESTAMP>.h5.
+
+ """
+ from os.path import exists
+ from select import select
+ from sys import stdin
+ from typing import Any
+
+ import h5py
+ import numpy as np
+ import pyaudio
+ import torch
+ from docopt import docopt
+ from docopt import printable_usage
+
+ import sonusai
+ from sonusai import create_file_handler
+ from sonusai import initial_log_messages
+ from sonusai import logger
+ from sonusai import update_console_handler
+ from sonusai.mixture import AudioT
+ from sonusai.mixture import CHANNEL_COUNT
+ from sonusai.mixture import SAMPLE_RATE
+ from sonusai.mixture import get_audio_from_feature
+ from sonusai.mixture import get_feature_from_audio
+ from sonusai.mixture import read_audio
+ from sonusai.utils import calc_asr
+ from sonusai.utils import create_timestamp
+ from sonusai.utils import get_input_device_index_by_name
+ from sonusai.utils import get_input_devices
+ from sonusai.utils import import_keras_model
+ from sonusai.utils import trim_docstring
+ from sonusai.utils import write_wav
+
+
+ def main() -> None:
+     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
+     ts = create_timestamp()
+
+     verbose = args['--verbose']
+     length = float(args['--length'])
+     input_name = args['--input']
+     model_name = args['--model']
+     ckpt_name = args['--checkpoint']
+     asr_name = args['--asr']
+     whisper_name = args['--whisper']
+     debug = args['--debug']
+     show = args['--show']
+
+     capture_name = f'audiofe_capture_{ts}.wav'
+     predict_name = f'audiofe_predict_{ts}.wav'
+     h5_name = f'audiofe_{ts}.h5'
+
+     if model_name is not None and ckpt_name is None:
+         print(printable_usage(trim_docstring(__doc__)))
+         exit(1)
+
+     # Setup logging file
+     create_file_handler('audiofe.log')
+     update_console_handler(verbose)
+     initial_log_messages('audiofe')
+
+     if show:
+         logger.info('List of available audio inputs:')
+         logger.info('')
+         p = pyaudio.PyAudio()
+         for name in get_input_devices(p):
+             logger.info(f'{name}')
+         logger.info('')
+         p.terminate()
+         return
+
+     if input_name is not None and exists(input_name):
+         capture_audio = get_frames_from_file(input_name, length)
+     else:
+         try:
+             capture_audio = get_frames_from_device(input_name, length)
+         except ValueError as e:
+             logger.exception(e)
+             return
+
+     write_wav(capture_name, capture_audio, SAMPLE_RATE)
+     logger.info('')
+     logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_name}')
+     if debug:
+         with h5py.File(h5_name, 'a') as f:
+             if 'capture_audio' in f:
+                 del f['capture_audio']
+             f.create_dataset('capture_audio', data=capture_audio)
+         logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {h5_name}')
+
+     if asr_name is not None:
+         capture_asr = calc_asr(capture_audio, engine=asr_name, whisper_model_name=whisper_name).text
+         logger.info(f'Capture audio ASR: {capture_asr}')
+
+     if model_name is not None:
+         model = load_model(model_name=model_name, ckpt_name=ckpt_name)
+
+         feature = get_feature_from_audio(audio=capture_audio, feature_mode=model.hparams.feature)
+         if debug:
+             with h5py.File(h5_name, 'a') as f:
+                 if 'feature' in f:
+                     del f['feature']
+                 f.create_dataset('feature', data=feature)
+             logger.info(f'Wrote feature with shape {feature.shape} to {h5_name}')
+
+         # if debug:
+         #     reconstruct_name = f'audiofe_reconstruct_{ts}.wav'
+         #     reconstruct_audio = get_audio_from_feature(feature=feature, feature_mode=model.hparams.feature)
+         #     samples = min(len(capture_audio), len(reconstruct_audio))
+         #     max_err = np.max(np.abs(capture_audio[:samples] - reconstruct_audio[:samples]))
+         #     logger.info(f'Maximum error between capture and reconstruct: {max_err}')
+         #     write_wav(reconstruct_name, reconstruct_audio, SAMPLE_RATE)
+         #     logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {reconstruct_name}')
+         #     with h5py.File(h5_name, 'a') as f:
+         #         if 'reconstruct_audio' in f:
+         #             del f['reconstruct_audio']
+         #         f.create_dataset('reconstruct_audio', data=reconstruct_audio)
+         #         logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {h5_name}')
+
+         with torch.no_grad():
+             predict = model(torch.tensor(feature))
+         if debug:
+             with h5py.File(h5_name, 'a') as f:
+                 if 'predict' in f:
+                     del f['predict']
+                 f.create_dataset('predict', data=predict)
+             logger.info(f'Wrote predict with shape {predict.shape} to {h5_name}')
+
+         predict_audio = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)
+         write_wav(predict_name, predict_audio, SAMPLE_RATE)
+         logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_name}')
+         if debug:
+             with h5py.File(h5_name, 'a') as f:
+                 if 'predict_audio' in f:
+                     del f['predict_audio']
+                 f.create_dataset('predict_audio', data=predict_audio)
+             logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {h5_name}')
+
+         if asr_name is not None:
+             predict_asr = calc_asr(predict_audio, engine=asr_name, whisper_model_name=whisper_name).text
+             logger.info(f'Predict audio ASR: {predict_asr}')
+
+
+ def load_model(model_name: str, ckpt_name: str) -> Any:
+     batch_size = 1
+     timesteps = 0
+
+     # Load checkpoint first to get hparams if available
+     try:
+         checkpoint = torch.load(ckpt_name, map_location=lambda storage, loc: storage)
+     except Exception as e:
+         logger.exception(f'Error: could not load checkpoint from {ckpt_name}: {e}')
+         raise SystemExit(1)
+
+     # Import model definition file
+     logger.info(f'Importing {model_name}')
+     litemodule = import_keras_model(model_name)
+
+     if 'hyper_parameters' in checkpoint:
+         logger.info(f'Found checkpoint file with hyper-parameters')
+         hparams = checkpoint['hyper_parameters']
+         if hparams['batch_size'] != batch_size:
+             logger.info(
+                 f'Overriding model default batch_size of {hparams["batch_size"]} with batch_size of {batch_size}')
+             hparams["batch_size"] = batch_size
+
+         if hparams['timesteps'] != 0 and timesteps == 0:
+             timesteps = hparams['timesteps']
+             logger.warning(f'Using model default timesteps of {timesteps}')
+
+         logger.info(f'Building model with {len(hparams)} total hparams')
+         try:
+             model = litemodule.MyHyperModel(**hparams)
+         except Exception as e:
+             logger.exception(f'Error: model build (MyHyperModel) in {model_name} failed: {e}')
+             raise SystemExit(1)
+     else:
+         logger.info(f'Found checkpoint file with no hyper-parameters')
+         logger.info(f'Building model with defaults')
+         try:
+             tmp = litemodule.MyHyperModel()
+         except Exception as e:
+             logger.exception(f'Error: model build (MyHyperModel) in {model_name} failed: {e}')
+             raise SystemExit(1)
+
+         if tmp.batch_size != batch_size:
+             logger.info(f'Overriding model default batch_size of {tmp.batch_size} with batch_size of {batch_size}')
+
+         if tmp.timesteps != 0 and timesteps == 0:
+             timesteps = tmp.timesteps
+             logger.warning(f'Using model default timesteps of {timesteps}')
+
+         model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)
+
+     logger.info(f'Loading weights from {ckpt_name}')
+     model.load_state_dict(checkpoint["state_dict"])
+     model.eval()
+     return model
+
+
+ def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1024) -> AudioT:
+     p = pyaudio.PyAudio()
+
+     input_devices = get_input_devices(p)
+     if not input_devices:
+         raise ValueError('No input audio devices found')
+
+     if input_name is None:
+         input_name = input_devices[0]
+
+     try:
+         device_index = get_input_device_index_by_name(p, input_name)
+     except ValueError:
+         msg = f'Could not find {input_name}\n'
+         msg += f'Available devices:\n'
+         for input_device in input_devices:
+             msg += f'  {input_device}\n'
+         raise ValueError(msg)
+
+     logger.info(f'Capturing from {p.get_device_info_by_index(device_index).get("name")}')
+     stream = p.open(format=pyaudio.paFloat32,
+                     channels=CHANNEL_COUNT,
+                     rate=SAMPLE_RATE,
+                     input=True,
+                     input_device_index=device_index)
+     stream.start_stream()
+
+     print()
+     print('+---------------------------------+')
+     print('| Press Enter to stop             |')
+     print('+---------------------------------+')
+     print()
+
+     elapsed = 0.0
+     seconds_per_chunk = float(chunk) / float(SAMPLE_RATE)
+     raw_frames = []
+     while elapsed < length or length == -1:
+         raw_frames.append(stream.read(num_frames=chunk, exception_on_overflow=False))
+         elapsed += seconds_per_chunk
+         if select([stdin, ], [], [], 0)[0]:
+             stdin.read(1)
+             length = elapsed
+
+     stream.stop_stream()
+     stream.close()
+     p.terminate()
+     frames = np.frombuffer(b''.join(raw_frames), dtype=np.float32)
+     return frames
+
+
+ def get_frames_from_file(input_name: str, length: float) -> AudioT:
+     logger.info(f'Capturing from {input_name}')
+     frames = read_audio(input_name)
+     if length != -1:
+         num_frames = int(length * SAMPLE_RATE)
+         if len(frames) > num_frames:
+             frames = frames[:num_frames]
+     return frames
+
+
+ if __name__ == '__main__':
+     main()
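For reference, the file-input path described in the audiofe docstring above reduces to a handful of the SonusAI helpers this release exposes. The sketch below is illustrative only: enhance_file is a hypothetical wrapper, and its model argument is assumed to be a loaded Lightning module like the one load_model() returns, exposing hparams.feature and mapping a feature tensor to a prediction.

import torch

from sonusai.mixture import SAMPLE_RATE
from sonusai.mixture import get_audio_from_feature
from sonusai.mixture import get_feature_from_audio
from sonusai.mixture import read_audio
from sonusai.utils import calc_asr
from sonusai.utils import write_wav


def enhance_file(input_wav: str, model, asr_engine: str = 'whisper') -> None:
    # Hypothetical wrapper mirroring the audiofe.py flow above; not part of the package.
    capture_audio = read_audio(input_wav)  # time-domain samples at SAMPLE_RATE
    feature = get_feature_from_audio(audio=capture_audio, feature_mode=model.hparams.feature)
    with torch.no_grad():
        predict = model(torch.tensor(feature))  # run the PL model on the feature tensor
    predict_audio = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)
    write_wav('predict.wav', predict_audio, SAMPLE_RATE)
    print(calc_asr(predict_audio, engine=asr_engine).text)  # ASR on the enhanced audio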
@@ -978,11 +978,11 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
      plot_fname = base_name + '_metric_spenh.pdf'

      # Reshape feature to eliminate overlap redundancy for easier to understand spectrogram view
-     # Original size (frames, stride, num_bands), decimates in stride dimension only if step is > 1
-     # Reshape to get frames*decimated_stride, num_bands
+     # Original size (frames, stride, feature_parameters), decimates in stride dimension only if step is > 1
+     # Reshape to get frames*decimated_stride, feature_parameters
      step = int(mixdb.feature_samples / mixdb.feature_step_samples)
      if feature.ndim != 3:
-         raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, num_bands')
+         raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, feature_parameters')

      # for feature cn*00n**
      feat_sgram = unstack_complex(feature)
@@ -42,7 +42,7 @@ class DatasetFromMixtureDatabase(Sequence):
          self.add1ch = add1ch
          self.shuffle = shuffle
          self.stride = self.mixdb.fg_stride
-         self.num_bands = self.mixdb.fg_num_bands
+         self.feature_parameters = self.mixdb.feature_parameters
          self.num_classes = self.mixdb.num_classes
          self.mixture_frame_segments = None
          self.batch_frame_segments = None
@@ -61,7 +61,7 @@ class KerasFromMixtureDatabase(Sequence):
          self.add1ch = add1ch
          self.shuffle = shuffle
          self.stride = self.mixdb.fg_stride
-         self.num_bands = self.mixdb.fg_num_bands
+         self.feature_parameters = self.mixdb.feature_parameters
          self.num_classes = self.mixdb.num_classes
          self.mixture_frame_segments: Optional[int] = None
          self.batch_frame_segments: Optional[int] = None
sonusai/genft.py CHANGED
@@ -165,7 +165,8 @@ def main() -> None:
      logger.info(f'Wrote {len(mixids)} mixtures to {location}')
      logger.info('')
      logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
-     logger.info(f'feature: {human_readable_size(total_feature_frames * mixdb.fg_stride * mixdb.fg_num_bands * 4, 1)}')
+     logger.info(
+         f'feature: {human_readable_size(total_feature_frames * mixdb.fg_stride * mixdb.feature_parameters * 4, 1)}')
      logger.info(f'truth_f: {human_readable_size(total_feature_frames * mixdb.num_classes * 4, 1)}')
      if compute_segsnr:
          logger.info(f'segsnr: {human_readable_size(total_transform_frames * 4, 1)}')
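The reworked feature log line above reports storage as total_feature_frames * fg_stride * feature_parameters * 4 bytes (float32). A quick illustration of the same arithmetic; the three input values below are invented for the example and are not taken from any real mixture database.

# Illustrative arithmetic only; values are made up.
total_feature_frames = 10_000
fg_stride = 64
feature_parameters = 160
feature_bytes = total_feature_frames * fg_stride * feature_parameters * 4  # 4 bytes per float32
print(f'{feature_bytes / 1e9:.2f} GB')  # 0.41 GB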
sonusai/genmixdb.py CHANGED
@@ -337,12 +337,12 @@ def genmixdb(location: str,
      log_duration_and_sizes(total_duration=total_duration,
                             num_classes=mixdb.num_classes,
                             feature_step_samples=mixdb.feature_step_samples,
-                            num_bands=mixdb.fg_num_bands,
+                            feature_parameters=mixdb.feature_parameters,
                             stride=mixdb.fg_stride,
                             desc='Estimated')
      logger.info(f'Feature shape: '
-                 f'{mixdb.fg_stride} x {mixdb.fg_num_bands} '
-                 f'({mixdb.fg_stride * mixdb.fg_num_bands} total params)')
+                 f'{mixdb.fg_stride} x {mixdb.feature_parameters} '
+                 f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
      logger.info(f'Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
      logger.info(f'Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)')
      logger.info('')
@@ -371,7 +371,7 @@ def genmixdb(location: str,
      log_duration_and_sizes(total_duration=total_duration,
                             num_classes=mixdb.num_classes,
                             feature_step_samples=mixdb.feature_step_samples,
-                            num_bands=mixdb.fg_num_bands,
+                            feature_parameters=mixdb.feature_parameters,
                             stride=mixdb.fg_stride,
                             desc='Actual')
      logger.info('')
sonusai/keras_predict.py CHANGED
@@ -180,7 +180,7 @@ def main() -> None:
      for file in input_name:
          # Convert WAV to feature data
          audio = read_audio(file)
-         feature = get_feature_from_audio(audio=audio, feature=hypermodel.feature)
+         feature = get_feature_from_audio(audio=audio, feature_mode=hypermodel.feature)

          feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                              built_model=built_model,
sonusai/lsdb.py CHANGED
@@ -48,8 +48,8 @@ def lsdb(mixdb: MixtureDatabase,
      logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
      logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
      logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
-     logger.info(f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.fg_num_bands} '
-                 f'({mixdb.fg_stride * mixdb.fg_num_bands} total params)')
+     logger.info(f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
+                 f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
      logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
      logger.info(f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples '
                  f'({mixdb.feature_step_ms} ms)')
sonusai/main.py CHANGED
@@ -3,9 +3,9 @@
  usage: sonusai [--version] [--help] <command> [<args>...]

  The sonusai commands are:
+     audiofe              Audio front end
      calc_metric_spenh    Run speech enhancement and analysis
      doc                  Documentation
-     evaluate             Evaluate model performance
      genft                Generate feature and truth data
      genmix               Generate mixture and truth data
      genmixdb             Generate a mixture database
@@ -39,9 +39,9 @@ def main() -> None:
      from sonusai.utils import trim_docstring

      commands = (
+         'audiofe',
          'calc_metric_spenh',
          'doc',
-         'evaluate',
          'genft',
          'genmix',
          'genmixdb',
@@ -1,6 +1,4 @@
  # SonusAI mixture utilities
- from .audio import calculate_audio_from_transform
- from .audio import calculate_transform_from_audio
  from .audio import get_duration
  from .audio import get_next_noise
  from .audio import get_num_samples
@@ -83,6 +81,7 @@ from .datatypes import TruthFunctionConfig
  from .datatypes import TruthSetting
  from .datatypes import TruthSettings
  from .datatypes import UniversalSNR
+ from .feature import get_audio_from_feature
  from .feature import get_feature_from_audio
  from .generation import generate_mixtures
  from .generation import get_all_snrs_from_config
@@ -102,8 +101,10 @@ from .helpers import augmented_noise_samples
  from .helpers import augmented_target_samples
  from .helpers import check_audio_files_exist
  from .helpers import forward_transform
+ from .helpers import get_audio_from_transform
  from .helpers import get_ft
  from .helpers import get_segsnr
+ from .helpers import get_transform_from_audio
  from .helpers import get_truth_t
  from .helpers import inverse_transform
  from .helpers import mixture_metadata
sonusai/mixture/audio.py CHANGED
@@ -1,11 +1,6 @@
  from functools import lru_cache

- from pyaaware import ForwardTransform
- from pyaaware import InverseTransform
-
- from sonusai.mixture.datatypes import AudioF
  from sonusai.mixture.datatypes import AudioT
- from sonusai.mixture.datatypes import EnergyT
  from sonusai.mixture.datatypes import ImpulseResponseData


@@ -22,35 +17,6 @@ def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
      return np.take(audio, range(offset, offset + length), mode='wrap')


- def calculate_transform_from_audio(audio: AudioT,
-                                    transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
-     """Apply forward transform to input audio data to generate transform data
-
-     :param audio: Time domain data [samples]
-     :param transform: ForwardTransform object
-     :return: Frequency domain data [frames, bins], Energy [frames]
-     """
-     f, e = transform.execute_all(audio)
-     return f.transpose(), e
-
-
- def calculate_audio_from_transform(data: AudioF,
-                                    transform: InverseTransform,
-                                    trim: bool = True) -> tuple[AudioT, EnergyT]:
-     """Apply inverse transform to input transform data to generate audio data
-
-     :param data: Frequency domain data [frames, bins]
-     :param transform: InverseTransform object
-     :param trim: Removes starting samples so output waveform will be time-aligned with input waveform to the transform
-     :return: Time domain data [samples], Energy [frames]
-     """
-     t, e = transform.execute_all(data.transpose())
-     if trim:
-         t = t[transform.N - transform.R:]
-
-     return t, e
-
-
  def get_duration(audio: AudioT) -> float:
      """Get duration of audio in seconds

@@ -304,7 +304,7 @@ class FeatureGeneratorInfo:
      decimation: int
      stride: int
      step: int
-     num_bands: int
+     feature_parameters: int
      ft_config: TransformConfig
      eft_config: TransformConfig
      it_config: TransformConfig
@@ -1,51 +1,105 @@
+ from typing import Optional
+
+ from sonusai.mixture.datatypes import AudioF
  from sonusai.mixture.datatypes import AudioT
  from sonusai.mixture.datatypes import Feature


- def get_feature_from_audio(audio: AudioT, feature: str) -> Feature:
-     from dataclasses import asdict
+ def get_feature_from_audio(audio: AudioT,
+                            feature_mode: str,
+                            num_classes: Optional[int] = 1,
+                            truth_mutex: Optional[bool] = False) -> Feature:
+     """Apply forward transform and generate feature data from audio data

+     :param audio: Time domain audio data [samples]
+     :param feature_mode: Feature mode
+     :param num_classes: Number of classes
+     :param truth_mutex: Whether to calculate 'other' label
+     :return: Feature data [frames, strides, feature_parameters]
+     """
      import numpy as np
      from pyaaware import FeatureGenerator

      from .augmentation import pad_audio_to_frame
-     from .datatypes import FeatureGeneratorConfig
      from .datatypes import TransformConfig
      from .helpers import forward_transform
-     from .truth import truth_reduction

-     num_classes = 1
-     truth_mutex = False
-     truth_reduction_function = 'max'
+     fg = FeatureGenerator(feature_mode=feature_mode,
+                           num_classes=num_classes,
+                           truth_mutex=truth_mutex)

-     fg_config = FeatureGeneratorConfig(feature_mode=feature,
-                                        num_classes=num_classes,
-                                        truth_mutex=truth_mutex)
-     fg = FeatureGenerator(**asdict(fg_config))
      feature_step_samples = fg.ftransform_R * fg.decimation * fg.step
-
      audio = pad_audio_to_frame(audio, feature_step_samples)
-     samples = len(audio)
-     audio_f = forward_transform(audio, TransformConfig(N=fg.ftransform_N,
+
+     audio_f = forward_transform(audio=audio,
+                                 config=TransformConfig(N=fg.ftransform_N,
                                                         R=fg.ftransform_R,
                                                         bin_start=fg.bin_start,
                                                         bin_end=fg.bin_end,
                                                         ttype=fg.ftransform_ttype))

+     samples = len(audio)
      transform_frames = samples // fg.ftransform_R
      feature_frames = samples // feature_step_samples

-     truth_t = np.empty((samples, num_classes), dtype=np.float32)
-
-     data = np.empty((feature_frames, fg.stride, fg.num_bands), dtype=np.float32)
+     feature = np.empty((feature_frames, fg.stride, fg.feature_parameters), dtype=np.float32)

      feature_frame = 0
      for transform_frame in range(transform_frames):
-         indices = slice(transform_frame * fg.ftransform_R, (transform_frame + 1) * fg.ftransform_R)
-         fg.execute(audio_f[transform_frame], truth_reduction(truth_t[indices], truth_reduction_function))
+         fg.execute(audio_f[transform_frame])

          if fg.eof():
-             data[feature_frame] = fg.feature()
+             feature[feature_frame] = fg.feature()
              feature_frame += 1

-     return data
+     return feature
+
+
+ def get_audio_from_feature(feature: Feature,
+                            feature_mode: str,
+                            num_classes: Optional[int] = 1,
+                            truth_mutex: Optional[bool] = False,
+                            trim: Optional[bool] = True) -> AudioT:
+     """Apply inverse transform to feature data to generate audio data
+
+     :param feature: Feature data [frames, strides, feature_parameters]
+     :param feature_mode: Feature mode
+     :param num_classes: Number of classes
+     :param truth_mutex: Whether to calculate 'other' label
+     :param trim: Whether to trim the audio data
+     :return: Audio data [samples]
+     """
+     import numpy as np
+
+     from pyaaware import FeatureGenerator
+
+     from .datatypes import TransformConfig
+     from .helpers import inverse_transform
+     from sonusai.utils.stacked_complex import unstack_complex
+
+     fg = FeatureGenerator(feature_mode=feature_mode,
+                           num_classes=num_classes,
+                           truth_mutex=truth_mutex)
+
+     feature_complex = unstack_complex(feature)
+     if feature_mode[0:1] == 'h':
+         feature_complex = _power_uncompress(feature_complex)
+     return np.squeeze(inverse_transform(transform=feature_complex,
+                                         config=TransformConfig(N=fg.itransform_N,
+                                                                R=fg.itransform_R,
+                                                                bin_start=fg.bin_start,
+                                                                bin_end=fg.bin_end,
+                                                                ttype=fg.itransform_ttype),
+                                         trim=trim))
+
+
+ def _power_uncompress(feature: AudioF) -> AudioF:
+     import numpy as np
+
+     mag = np.abs(feature)
+     phase = np.angle(feature)
+     mag = mag ** (1. / 0.3)
+     real_uncompress = mag * np.cos(phase)
+     imag_uncompress = mag * np.sin(phase)
+
+     return real_uncompress + 1j * imag_uncompress
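The new get_feature_from_audio / get_audio_from_feature pair allows a feature round trip without a mixture database, which is what the commented-out reconstruct check in audiofe.py above exercises. A minimal sketch, assuming a stacked-complex feature mode; the mode string below is a placeholder, not a verified name from the package.

import numpy as np

from sonusai.mixture import SAMPLE_RATE
from sonusai.mixture import get_audio_from_feature
from sonusai.mixture import get_feature_from_audio

audio = np.random.default_rng(0).normal(size=SAMPLE_RATE).astype(np.float32)  # 1 second of noise
feature = get_feature_from_audio(audio=audio, feature_mode='hn00ns1')  # placeholder feature mode
reconstruct = get_audio_from_feature(feature=feature, feature_mode='hn00ns1')
samples = min(len(audio), len(reconstruct))
max_err = np.max(np.abs(audio[:samples] - reconstruct[:samples]))
print(f'Maximum error between input and reconstruct: {max_err}')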