sonusai 0.15.8__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. sonusai/__init__.py +35 -4
  2. sonusai/audiofe.py +237 -0
  3. sonusai/calc_metric_spenh.py +21 -12
  4. sonusai/genft.py +2 -1
  5. sonusai/genmixdb.py +5 -5
  6. sonusai/lsdb.py +2 -2
  7. sonusai/main.py +58 -61
  8. sonusai/mixture/__init__.py +4 -2
  9. sonusai/mixture/audio.py +0 -34
  10. sonusai/mixture/config.py +1 -2
  11. sonusai/mixture/datatypes.py +1 -1
  12. sonusai/mixture/feature.py +75 -21
  13. sonusai/mixture/helpers.py +60 -30
  14. sonusai/mixture/log_duration_and_sizes.py +2 -2
  15. sonusai/mixture/mixdb.py +13 -10
  16. sonusai/mixture/spectral_mask.py +14 -14
  17. sonusai/mixture/truth_functions/data.py +1 -1
  18. sonusai/mixture/truth_functions/target.py +2 -2
  19. sonusai/mkmanifest.py +29 -2
  20. sonusai/onnx_predict.py +1 -1
  21. sonusai/plot.py +4 -4
  22. sonusai/post_spenh_targetf.py +8 -8
  23. sonusai/utils/__init__.py +8 -7
  24. sonusai/utils/asl_p56.py +3 -3
  25. sonusai/utils/asr.py +35 -8
  26. sonusai/utils/asr_functions/__init__.py +0 -5
  27. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  28. sonusai/utils/asr_manifest_functions/__init__.py +1 -0
  29. sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
  30. sonusai/utils/audio_devices.py +41 -0
  31. sonusai/utils/calculate_input_shape.py +3 -4
  32. sonusai/utils/create_timestamp.py +5 -0
  33. sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
  34. sonusai/utils/model_utils.py +30 -0
  35. sonusai/utils/onnx_utils.py +19 -45
  36. sonusai/utils/reshape.py +11 -11
  37. sonusai/utils/wave.py +12 -5
  38. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/METADATA +8 -19
  39. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/RECORD +41 -54
  40. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/WHEEL +1 -1
  41. sonusai/data_generator/__init__.py +0 -5
  42. sonusai/data_generator/dataset_from_mixdb.py +0 -143
  43. sonusai/data_generator/keras_from_mixdb.py +0 -169
  44. sonusai/data_generator/torch_from_mixdb.py +0 -122
  45. sonusai/evaluate.py +0 -245
  46. sonusai/keras_onnx.py +0 -86
  47. sonusai/keras_predict.py +0 -231
  48. sonusai/keras_train.py +0 -334
  49. sonusai/torchl_onnx.py +0 -216
  50. sonusai/torchl_predict.py +0 -547
  51. sonusai/torchl_train.py +0 -223
  52. sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
  53. sonusai/utils/asr_functions/data.py +0 -16
  54. sonusai/utils/asr_functions/deepgram.py +0 -97
  55. sonusai/utils/asr_functions/fastwhisper.py +0 -90
  56. sonusai/utils/asr_functions/google.py +0 -95
  57. sonusai/utils/asr_functions/whisper.py +0 -49
  58. sonusai/utils/keras_utils.py +0 -226
  59. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/entry_points.txt +0 -0
sonusai/torchl_predict.py DELETED
@@ -1,547 +0,0 @@
-"""sonusai torchl_predict
-
-usage: torchl_predict [-hvrw] [-i MIXID] [-a ACCEL] [-p PREC] [-d DLCPU] [-m MODEL]
-                      (-k CKPT) [-b BATCH] [-t TSTEPS] INPUT ...
-
-options:
-    -h, --help
-    -v, --verbose                   Be verbose.
-    -i MIXID, --mixid MIXID         Mixture ID(s) to use if input is a mixture database. [default: *].
-    -a ACCEL, --accelerator ACCEL   Accelerator to use in PL trainer in non-reset mode [default: auto]
-    -p PREC, --precision PREC       Precision to use in PL trainer in non-reset mode. [default: 32]
-    -d DLCPU, --dataloader-cpus     Number of workers/cpus for dataloader. [default: 0]
-    -m MODEL, --model MODEL         PL model .py file path.
-    -k CKPT, --checkpoint CKPT      PL checkpoint file with weights.
-    -b BATCH, --batch BATCH         Batch size (deprecated and forced to 1). [default: 1]
-    -t TSTEPS, --tsteps TSTEPS      Timesteps. If 0, dim is not included/expected in model. [default: 0]
-    -r, --reset                     Reset model between each file.
-    -w, --wavdbg                    Write debug .wav files of feature input, truth, and predict. [default: False]
-
-Run PL (Pytorch Lightning) prediction with model and checkpoint input using input data from a
-SonusAI mixture database.
-The PL model is imported from MODEL .py file and weights loaded from checkpoint file CKPT.
-
-Inputs:
-    ACCEL   Accelerator used for PL prediction. As of PL v2.0.8: auto, cpu, cuda, hpu, ipu, mps, tpu
-    PREC    Precision used in PL prediction. PL trainer will convert model+weights to specified prec.
-            As of PL v2.0.8:
-            ('16-mixed', 'bf16-mixed', '32-true', '64-true', 64, 32, 16, '64', '32', '16', 'bf16')
-    MODEL   Path to a .py with MyHyperModel PL model class definition
-    CKPT    A PL checkpoint file with weights.
-    INPUT   The input data must be one of the following:
-            * directory
-              Use SonusAI mixture database directory, generate feature and truth data if not found.
-              Run prediction on the feature. The MIXID is required (or default which is *)
-
-            * Single WAV file or glob of WAV files
-              Using the given model, generate feature data and run prediction. A model file must be
-              provided. The MIXID is ignored.
-
-Outputs the following to tpredict-<TIMESTAMP> directory:
-    <id>.h5
-        dataset: predict
-    torch_predict.log
-
-"""
-from os.path import join
-from typing import Any
-
-import h5py
-import torch
-from lightning.pytorch.callbacks import BasePredictionWriter
-
-from sonusai import logger
-from sonusai.mixture import Feature
-
-
-class CustomWriter(BasePredictionWriter):
-    def __init__(self, output_dir, write_interval):
-        super().__init__(write_interval)
-        self.output_dir = output_dir
-
-    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
-        # this will create N (num processes) files in `output_dir` each containing
-        # the predictions of it's respective rank
-        # torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
-
-        # optionally, you can also save `batch_indices` to get the information about the data index
-        # from your prediction data
-        num_dev = len(batch_indices)
-        logger.debug(f'Num dev: {num_dev}, prediction writer global rank: {trainer.global_rank}')
-        len_pred = len(predictions)  # for debug, should be num_dev
-        logger.debug(f'len predictions: {len_pred}, len batch_indices0 {len(batch_indices[0])}')
-        logger.debug(f'Prediction writer batch indices: {batch_indices}')
-
-        logger.info(f'Predictions returned: {len(predictions)}, writing to .h5 files ...')
-        for ndi in range(num_dev):  # iterate over list devices (num of batch groups)
-            num_batches = len(batch_indices[ndi])  # num batches in dev
-            for bi in range(num_batches):  # iterate over list of batches per dev
-                bsz = len(batch_indices[ndi][bi])  # batch size
-                for di in range(bsz):
-                    gid = batch_indices[0][bi][di]
-                    # gid = (bgi+1)*bi + bi
-                    # gid = bgi + bi
-                    logger.debug(f'{ndi}, {bi}, {di}: global id: {gid}')
-                    output_name = join(self.output_dir, trainer.predict_dataloaders.dataset.mixdb.mixtures[gid].name)
-                    # output_name = join(output_dir, mixdb.mixtures[i].name)
-                    pdat = predictions[bi][di, None].cpu().numpy()
-                    logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-                    with h5py.File(output_name, 'a') as f:
-                        if 'predict' in f:
-                            del f['predict']
-                        f.create_dataset('predict', data=pdat)
-
-        # output_name = join(self.output_dir,trainer.predict_dataloaders.dataset.mixdb.mixtures[0].name)
-        # logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-        # torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
-
-
-def power_compress(x):
-    real = x[..., 0]
-    imag = x[..., 1]
-    spec = torch.complex(real, imag)
-    mag = torch.abs(spec)
-    phase = torch.angle(spec)
-    mag = mag ** 0.3
-    real_compress = mag * torch.cos(phase)
-    imag_compress = mag * torch.sin(phase)
-    return torch.stack([real_compress, imag_compress], 1)
-
-
-def power_uncompress(real, imag):
-    spec = torch.complex(real, imag)
-    mag = torch.abs(spec)
-    phase = torch.angle(spec)
-    mag = mag ** (1. / 0.3)
-    real_compress = mag * torch.cos(phase)
-    imag_compress = mag * torch.sin(phase)
-    return torch.stack([real_compress, imag_compress], -1)
-
-
-def main() -> None:
-    from docopt import docopt
-
-    import sonusai
-    from sonusai.utils import trim_docstring
-
-    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
-
-    verbose = args['--verbose']
-    mixids = args['--mixid']
-    accel = args['--accelerator']
-    prec = args['--precision']
-    dlcpu = int(args['--dataloader-cpus'])
-    modelpath = args['--model']
-    ckpt_name = args['--checkpoint']
-    batch_size = args['--batch']
-    timesteps = args['--tsteps']
-    reset = args['--reset']
-    wavdbg = args['--wavdbg']  # write .wav if true
-    input_name = args['INPUT']
-
-    from os import makedirs
-    from os.path import basename
-    from os.path import isdir
-    from os.path import isfile
-    from os.path import join
-    from os.path import splitext
-    from os.path import normpath
-    import h5py
-    # from sonusai.utils import float_to_int16
-
-    from torchinfo import summary
-    from sonusai import create_file_handler
-    from sonusai import initial_log_messages
-    from sonusai import update_console_handler
-    from sonusai.mixture import MixtureDatabase
-    from sonusai.mixture import get_feature_from_audio
-    from sonusai.utils import import_keras_model
-    from sonusai.mixture import read_audio
-    from sonusai.utils import create_ts_name
-    from sonusai.data_generator import TorchFromMixtureDatabase
-
-    if batch_size is not None:
-        batch_size = int(batch_size)
-        if batch_size != 1:
-            batch_size = 1
-            logger.info(f'For now prediction only supports batch_size = 1, forcing it to 1 now')
-
-    if timesteps is not None:
-        timesteps = int(timesteps)
-
-    if len(input_name) == 1 and isdir(input_name[0]):
-        in_basename = basename(normpath(input_name[0]))
-    else:
-        in_basename = ''
-
-    output_dir = create_ts_name('tpredict-' + in_basename)
-    makedirs(output_dir, exist_ok=True)
-
-    # Setup logging file
-    logger.info(f'Created output subdirectory {output_dir}')
-    create_file_handler(join(output_dir, 'torchl_predict.log'))
-    update_console_handler(verbose)
-    initial_log_messages('torch_predict')
-    logger.info(f'torch {torch.__version__}')
-
-    # Load checkpoint first to get hparams if available
-    try:
-        checkpoint = torch.load(ckpt_name, map_location=lambda storage, loc: storage)
-    except Exception as e:
-        logger.exception(f'Error: could not load checkpoint from {ckpt_name}: {e}')
-        raise SystemExit(1)
-
-    # Import model definition file
-    model_base = basename(modelpath)
-    model_root = splitext(model_base)[0]
-    logger.info(f'Importing {modelpath}')
-    litemodule = import_keras_model(modelpath)
-
-    if 'hyper_parameters' in checkpoint:
-        hparams = checkpoint['hyper_parameters']
-        logger.info(f'Found checkpoint file with hyper-params named {checkpoint["hparams_name"]} '
-                    f'with {len(hparams)} total hparams.')
-        if batch_size is not None and hparams['batch_size'] != batch_size:
-            if batch_size != 1:
-                batch_size = 1
-                logger.info(f'For now prediction only supports batch_size = 1, forcing it to 1 now')
-            logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
-            hparams["batch_size"] = batch_size
-
-        if timesteps is not None:
-            if hparams['timesteps'] == 0 and timesteps != 0:
-                logger.warning(f'Model does not contain timesteps; ignoring override.')
-                timesteps = 0
-
-            if hparams['timesteps'] != 0 and timesteps == 0:
-                logger.warning(f'Model contains timesteps; ignoring override, using model default.')
-                timesteps = hparams['timesteps']
-
-            if hparams['timesteps'] != timesteps:
-                logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
-                hparams['timesteps'] = timesteps
-
-        logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
-        try:
-            model = litemodule.MyHyperModel(**hparams)  # use hparams
-            # litemodule.MyHyperModel.load_from_checkpoint(ckpt_name, **hparams)
-        except Exception as e:
-            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
-            raise SystemExit(1)
-    else:
-        logger.info(f'Warning: found checkpoint with no hyper-parameters, building model with defaults')
-        try:
-            tmp = litemodule.MyHyperModel()  # use default hparams
-        except Exception as e:
-            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
-            raise SystemExit(1)
-
-        if batch_size is not None:
-            if tmp.batch_size != batch_size:
-                logger.info(f'Overriding batch_size: default = {tmp.batch_size}; specified = {batch_size}.')
-        else:
-            batch_size = tmp.batch_size  # inherit
-
-        if timesteps is not None:
-            if tmp.timesteps == 0 and timesteps != 0:
-                logger.warning(f'Model does not contain timesteps; ignoring override.')
-                timesteps = 0
-
-            if tmp.timesteps != 0 and timesteps == 0:
-                logger.warning(f'Model contains timesteps; ignoring override.')
-                timesteps = tmp.timesteps
-
-            if tmp.timesteps != timesteps:
-                logger.info(f'Overriding timesteps: default = {tmp.timesteps}; specified = {timesteps}.')
-        else:
-            timesteps = tmp.timesteps
-
-        logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
-        model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)
-
-    logger.info('')
-    logger.info(summary(model))
-    logger.info('')
-    logger.info(f'feature {model.hparams.feature}')
-    logger.info(f'num_classes {model.num_classes}')
-    logger.info(f'batch_size {model.hparams.batch_size}')
-    logger.info(f'timesteps {model.hparams.timesteps}')
-    logger.info(f'flatten {model.flatten}')
-    logger.info(f'add1ch {model.add1ch}')
-    logger.info(f'truth_mutex {model.truth_mutex}')
-    logger.info(f'input_shape {model.input_shape}')
-    logger.info('')
-    logger.info(f'Loading weights from {ckpt_name}')
-    # model = model.load_from_checkpoint(ckpt_name)  # weights only, needs investigation
-    model.load_state_dict(checkpoint["state_dict"])
-    model.eval()
-
-    logger.info('')
-    # Load mixture database and setup dataloader
-    if len(input_name) == 1 and isdir(input_name[0]):  # Single path to mixdb subdir
-        input_name = input_name[0]
-        logger.info(f'Loading mixture database from {input_name}')
-        mixdb = MixtureDatabase(input_name)
-        logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
-
-        if mixdb.feature != model.hparams.feature:
-            logger.warning(f'Feature in mixture database {mixdb.feature} does not match feature in model')
-            # raise SystemExit(1)
-
-        # TBD check num_classes ??
-
-        p_mixids = mixdb.mixids_to_list(mixids)
-        sampler = None
-        p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
-                                             mixids=p_mixids,
-                                             batch_size=model.hparams.batch_size,
-                                             cut_len=0,
-                                             flatten=model.flatten,
-                                             add1ch=model.add1ch,
-                                             random_cut=False,
-                                             sampler=sampler,
-                                             drop_last=False,
-                                             num_workers=dlcpu)
-
-        if wavdbg:  # setup for wav write if enabled
-            # Info needed to setup inverse transform
-            from pyaaware import FeatureGenerator
-            from pyaaware import TorchInverseTransform
-            from torchaudio import save
-            # from sonusai.utils import write_wav
-
-            half = model.num_classes // 2
-            fg = FeatureGenerator(feature_mode=model.hparams.feature,
-                                  num_classes=model.num_classes,
-                                  truth_mutex=model.truth_mutex)
-            itf = TorchInverseTransform(N=fg.itransform_N,
-                                        R=fg.itransform_R,
-                                        bin_start=fg.bin_start,
-                                        bin_end=fg.bin_end,
-                                        ttype=fg.itransform_ttype)
-
-            if mixdb.target_files[0].truth_settings[0].function == 'target_f' or \
-                    mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
-                enable_truth_wav = True
-            else:
-                enable_truth_wav = False
-
-            if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
-                enable_mix_wav = True
-            else:
-                enable_mix_wav = False
-
-        if reset:
-            logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
-            for idx, val in enumerate(p_datagen):
-                # truth = val[1]
-                feature = val[0]
-                with torch.no_grad():
-                    ypred = model(feature)
-                output_name = join(output_dir, mixdb.mixtures[idx].name)
-                pdat = ypred.detach().numpy()
-                if timesteps > 0:
-                    logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
-                logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-                with h5py.File(output_name, 'a') as f:
-                    if 'predict' in f:
-                        del f['predict']
-                    f.create_dataset('predict', data=pdat)
-
-                if wavdbg:
-                    owav_base = splitext(output_name)[0]
-                    tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
-                    predwav, _ = itf.execute_all(tmp)
-                    # predwav, _ = calculate_audio_from_transform(tmp, itf, trim=True)
-                    save(owav_base + '.wav', predwav.permute([1, 0]), 16000, encoding='PCM_S', bits_per_sample=16)
-                    if enable_truth_wav:
-                        # Note this support truth type target_f and target_mixture_f
-                        tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
-                        truthwav, _ = itf.execute_all(tmp)
-                        save(owav_base + '_truth.wav', truthwav.permute([1, 0]), 16000, encoding='PCM_S',
-                             bits_per_sample=16)
-
-                    if enable_mix_wav:
-                        tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
-                        mixwav, _ = itf.execute_all(tmp.detach())
-                        save(owav_base + "_mix.wav", mixwav.permute([1, 0]), 16000, encoding='PCM_S',
-                             bits_per_sample=16)
-                        # write_wav(owav_base + "_truth.wav", truthwav, 16000)
-
-        else:
-            logger.info(f'Running {mixdb.num_mixtures} mixtures with model builtin prediction loop ...')
-            from lightning.pytorch import Trainer
-            pred_writer = CustomWriter(output_dir=output_dir, write_interval="epoch")
-            trainer = Trainer(default_root_dir=output_dir,
-                              callbacks=[pred_writer],
-                              precision=prec,
-                              devices='auto',
-                              accelerator=accel)  # prints avail GPU, TPU, IPU, HPU and selected device
-            # trainer = Trainer(default_root_dir=output_dir,
-            #                   devices='auto',
-            #                   accelerator='auto')  # prints avail GPU, TPU, IPU, HPU and selected device
-            # logger.info(f'Strategy: {trainer.strategy.strategy_name}')  # doesn't work for ddp strategy
-            logger.info(f'Accelerator stats: {trainer.accelerator.get_device_stats(device=None)}')
-            logger.info(f'World size: {trainer.world_size}')
-            logger.info(f'Nodes: {trainer.num_nodes}')
-            logger.info(f'Devices: {trainer.accelerator.auto_device_count()}')
-
-            # Use builtin lightning prediction loop, returns a list
-            # predictions = trainer.predict(model, p_datagen)  # standard method, but no support distributed
-            with torch.no_grad():
-                trainer.predict(model, p_datagen)
-            # predictions = model.predict_outputs
-            # pred_batch_idx = model.predict_batch_idx
-            # if trainer.world_size > 1:
-            # ddp_max_mem = torch.cuda.max_memory_allocated(trainer.local_rank) / 1000
-            # logger.info(f"GPU {trainer.local_rank} max memory using DDP: {ddp_max_mem:.2f} MB")
-            # if not trainer.is_global_zero:
-            # return
-            # logger.debug(f'type predictions: {type(predictions)}, type batch_idx: {type(pred_batch_idx)}')
-            # logger.debug(f'# predictions: {len(predictions)}, # batch_idx: {len(pred_batch_idx)}')
-            # logger.debug(f'{pred_batch_idx}')
-            # # # all_predictions = torch.cat(predictions) # predictions = torch.cat(predictions).cpu()
-            # # if trainer.world_size > 1:
-            # # # print(f'Predictions returned: {len(all_predictions)}')
-            # # ddp_max_mem = torch.cuda.max_memory_allocated(trainer.local_rank) / 1000
-            # # logger.info(f"GPU {trainer.local_rank} max memory using DDP: {ddp_max_mem:.2f} MB")
-            # # gathered = [None] * torch.distributed.get_world_size()
-            # # torch.distributed.all_gather_object(gathered, predictions)
-            # # torch.distributed.all_gather_object(gathered, pred_batch_idx)
-            # # torch.distributed.barrier()
-            # # if not trainer.is_global_zero:
-            # # return
-            # # predictions = sum(gathered, [])
-            # # if trainer.global_rank == 0:
-            # # logger.info(f"All predictions gathered: {len(predictions)}")
-            #
-            # logger.info(f'Predictions returned: {len(predictions)}, writing to .h5 files ...')
-            # #for idx, mixid in enumerate(p_mixids):
-            # for i in pred_batch_idx: # note assumes batch 0:num_mix matches 0:num_mix in mixdb.mixtures
-            # # print(f'{idx}, {mixid}')
-            # output_name = join(output_dir, mixdb.mixtures[i].name)
-            # pdat = predictions[i].cpu().numpy()
-            # logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-            # with h5py.File(output_name, 'a') as f:
-            # if 'predict' in f:
-            # del f['predict']
-            # f.create_dataset('predict', data=pdat)
-            #
-            # if wavdbg:
-            # owav_base = splitext(output_name)[0]
-            # tmp = torch.complex(predictions[idx][..., :half], predictions[idx][..., half:]).permute(2, 1, 0)
-            # predwav, _ = itf.execute_all(tmp.squeeze().detach().numpy())
-            # write_wav(owav_base + ".wav", predwav.detach().numpy(), 16000)
-
-            logger.info(f'Saved results to {output_dir}')
-            return
-
-    # if reset:
-    # # reset mode cycles through each file one at a time
-    # for mixid in mixids:
-    # feature, _ = mixdb.mixture_ft(mixid)
-    # if feature.shape[0] > 2500:
-    # print(f'Trimming input frames from {feature.shape[0]} to {2500},')
-    # feature = feature[0:2500,::]
-    # half = feature.shape[-1] // 2
-    # noisy_spec_cmplx = torch.complex(torch.tensor(feature[..., :half]),
-    # torch.tensor(feature[..., half:])).to(device)
-    # del feature
-    #
-    # predict = _pad_and_predict(built_model=model, feature=noisy_spec_cmplx)
-    # del noisy_spec_cmplx
-    #
-    # audio_est = torch_istft_olsa_hanns(predict, mixdb.it_config.N, mixdb.it_config.R).cpu()
-    # del predict
-    # output_name = join(output_dir, splitext(mixdb.mixtures[mixid].name)[0]+'.wav')
-    # print(f'Saving prediction to {output_name}')
-    # write_wav(name=output_name, audio=float_to_int16(audio_est.detach().numpy()).transpose())
-    #
-    # torch.cuda.empty_cache()
-    #
-    # # TBD .h5 predict file optional output file
-    # # output_name = join(output_dir, mixdb.mixtures[mixid].name)
-    # # with h5py.File(output_name, 'a') as f:
-    # # if 'predict' in f:
-    # # del f['predict']
-    # # f.create_dataset(name='predict', data=predict)
-    #
-    # else:
-    # # Run all data at once using a data generator
-    # feature = KerasFromH5(mixdb=mixdb,
-    # mixids=mixids,
-    # batch_size=hypermodel.batch_size,
-    # timesteps=hypermodel.timesteps,
-    # flatten=hypermodel.flatten,
-    # add1ch=hypermodel.add1ch)
-    #
-    # predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
-    # predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
-    #
-    # # Write data to separate files
-    # for idx, mixid in enumerate(mixids):
-    # output_name = join(output_dir, mixdb.mixtures[mixid].name)
-    # with h5py.File(output_name, 'a') as f:
-    # if 'predict' in f:
-    # del f['predict']
-    # f.create_dataset('predict', data=predict[feature.file_indices[idx]])
-    #
-    # logger.info(f'Saved results to {output_dir}')
-    # return
-
-    if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
-        logger.exception(f'Do not know how to process input from {input_name}')
-        raise SystemExit(1)
-
-    logger.info(f'Run prediction on {len(input_name):,} WAV files')
-    for file in input_name:
-        # Convert WAV to feature data
-        audio = read_audio(file)
-        feature = get_feature_from_audio(audio=audio, feature=model.feature)
-
-        # feature, predict = _pad_and_predict(hypermodel=hypermodel,
-        # built_model=built_model,
-        # feature=feature,
-        # frames_per_batch=frames_per_batch)
-
-        # clean = torch_istft_olsa_hanns(clean_spec_cmplx, mixdb.ift_config.N, mixdb.ift_config.R)
-
-        output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
-        with h5py.File(output_name, 'a') as f:
-            if 'feature' in f:
-                del f['feature']
-            f.create_dataset(name='feature', data=feature)
-
-            # if 'predict' in f:
-            # del f['predict']
-            # f.create_dataset(name='predict', data=predict)
-
-    logger.info(f'Saved results to {output_dir}')
-    del model
-
-
-def _pad_and_predict(built_model: Any, feature: Feature) -> torch.Tensor:
-    """
-    Run prediction on feature [frames,1,bins*2] (stacked complex numpy array, stride/tsteps=1)
-    Returns predict output [batch,frames,bins] in complex torch.tensor
-    """
-    noisy_spec = power_compress(torch.view_as_real(torch.from_numpy(feature).permute(1, 0, 2)))
-    # print(f'noisy_spec type {type(noisy_spec_cmplx)}')
-    # print(f'noisy_spec dtype {noisy_spec_cmplx.dtype}')
-    # print(f'noisy_spec size {noisy_spec_cmplx.shape}')
-    with torch.no_grad():
-        est_real, est_imag = built_model(noisy_spec)  # expects in size [batch, 2, tsteps, bins]
-        est_real, est_imag = est_real.permute(0, 1, 3, 2), est_imag.permute(0, 1, 3, 2)
-        est_spec_uncompress = torch.view_as_complex(power_uncompress(est_real, est_imag).squeeze(1))
-    # inv tf want [ch,frames,bins] complex (synonymous with [batch,tsteps,bins]), keep as torch.tensor
-    predict = est_spec_uncompress.permute(0, 2, 1)  # .detach().numpy()
-
-    return predict
-
-
-if __name__ == '__main__':
-    try:
-        main()
-    except KeyboardInterrupt:
-        logger.info('Canceled due to keyboard interrupt')
-        exit()
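Note on the removed spectral helpers: power_compress/power_uncompress in the deleted module apply a power-law compression to the spectral magnitude (mag ** 0.3) while preserving phase, and its exact inverse (mag ** (1 / 0.3)); the model operates on the compressed spectrum and the output is uncompressed before the inverse transform back to audio. A minimal, self-contained sketch of the round trip follows; the tensor shapes are illustrative, not necessarily SonusAI's exact feature layout:

    import torch

    def power_compress(x: torch.Tensor) -> torch.Tensor:
        # x: [..., 2] real/imag pairs; compress the magnitude, keep the phase,
        # and stack real/imag on dim 1, as the deleted helper did
        spec = torch.complex(x[..., 0], x[..., 1])
        mag = torch.abs(spec) ** 0.3
        phase = torch.angle(spec)
        return torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], 1)

    def power_uncompress(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
        # inverse: raise the magnitude back to the 1/0.3 power, stack on the last dim
        spec = torch.complex(real, imag)
        mag = torch.abs(spec) ** (1.0 / 0.3)
        phase = torch.angle(spec)
        return torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], -1)

    # round trip: compress then uncompress recovers the original spectrum
    x = torch.view_as_real(torch.randn(4, 100, 161, dtype=torch.complex64))  # [batch, frames, bins, 2]
    c = power_compress(x)                    # [batch, 2, frames, bins]
    y = power_uncompress(c[:, 0], c[:, 1])   # [batch, frames, bins, 2]
    print(torch.allclose(x, y, atol=1e-4))   # True up to floating-point error

Compressing magnitudes this way is commonly used to reduce the dynamic range the network must model; the uncompress step restores linear magnitudes so the inverse transform produces correctly scaled audio.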