sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sonusai/__init__.py +36 -4
  2. sonusai/audiofe.py +111 -106
  3. sonusai/calc_metric_spenh.py +38 -22
  4. sonusai/genft.py +15 -6
  5. sonusai/genmix.py +14 -6
  6. sonusai/genmixdb.py +15 -7
  7. sonusai/gentcst.py +13 -6
  8. sonusai/lsdb.py +15 -5
  9. sonusai/main.py +58 -61
  10. sonusai/mixture/__init__.py +1 -0
  11. sonusai/mixture/config.py +1 -2
  12. sonusai/mkmanifest.py +43 -8
  13. sonusai/mkwav.py +15 -6
  14. sonusai/onnx_predict.py +16 -6
  15. sonusai/plot.py +16 -6
  16. sonusai/post_spenh_targetf.py +13 -6
  17. sonusai/summarize_metric_spenh.py +71 -0
  18. sonusai/tplot.py +14 -6
  19. sonusai/utils/__init__.py +4 -7
  20. sonusai/utils/asl_p56.py +3 -3
  21. sonusai/utils/asr.py +35 -8
  22. sonusai/utils/asr_functions/__init__.py +0 -5
  23. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  24. sonusai/utils/asr_manifest_functions/__init__.py +1 -0
  25. sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
  26. sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
  27. sonusai/utils/model_utils.py +30 -0
  28. sonusai/utils/onnx_utils.py +19 -45
  29. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
  30. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
  31. sonusai/data_generator/__init__.py +0 -5
  32. sonusai/data_generator/dataset_from_mixdb.py +0 -143
  33. sonusai/data_generator/keras_from_mixdb.py +0 -169
  34. sonusai/data_generator/torch_from_mixdb.py +0 -122
  35. sonusai/keras_onnx.py +0 -86
  36. sonusai/keras_predict.py +0 -231
  37. sonusai/keras_train.py +0 -334
  38. sonusai/torchl_onnx.py +0 -216
  39. sonusai/torchl_predict.py +0 -542
  40. sonusai/torchl_train.py +0 -223
  41. sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
  42. sonusai/utils/asr_functions/data.py +0 -16
  43. sonusai/utils/asr_functions/deepgram.py +0 -97
  44. sonusai/utils/asr_functions/fastwhisper.py +0 -90
  45. sonusai/utils/asr_functions/google.py +0 -95
  46. sonusai/utils/asr_functions/whisper.py +0 -49
  47. sonusai/utils/keras_utils.py +0 -226
  48. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
  49. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/torchl_predict.py DELETED
@@ -1,542 +0,0 @@
- """sonusai torchl_predict
-
- usage: torchl_predict [-hvrw] [-i MIXID] [-a ACCEL] [-p PREC] [-d DLCPU] [-m MODEL]
-                       (-k CKPT) [-b BATCH] [-t TSTEPS] INPUT ...
-
- options:
-     -h, --help
-     -v, --verbose                    Be verbose.
-     -i MIXID, --mixid MIXID          Mixture ID(s) to use if input is a mixture database. [default: *].
-     -a ACCEL, --accelerator ACCEL    Accelerator to use in PL trainer in non-reset mode [default: auto]
-     -p PREC, --precision PREC        Precision to use in PL trainer in non-reset mode. [default: 32]
-     -d DLCPU, --dataloader-cpus      Number of workers/cpus for dataloader. [default: 0]
-     -m MODEL, --model MODEL          PL model .py file path.
-     -k CKPT, --checkpoint CKPT       PL checkpoint file with weights.
-     -b BATCH, --batch BATCH          Batch size (deprecated and forced to 1). [default: 1]
-     -t TSTEPS, --tsteps TSTEPS       Timesteps. If 0, dim is not included/expected in model. [default: 0]
-     -r, --reset                      Reset model between each file.
-     -w, --wavdbg                     Write debug .wav files of feature input, truth, and predict. [default: False]
-
- Run PL (Pytorch Lightning) prediction with model and checkpoint input using input data from a
- SonusAI mixture database.
- The PL model is imported from MODEL .py file and weights loaded from checkpoint file CKPT.
-
- Inputs:
-     ACCEL   Accelerator used for PL prediction. As of PL v2.0.8: auto, cpu, cuda, hpu, ipu, mps, tpu
-     PREC    Precision used in PL prediction. PL trainer will convert model+weights to specified prec.
-             As of PL v2.0.8:
-             ('16-mixed', 'bf16-mixed', '32-true', '64-true', 64, 32, 16, '64', '32', '16', 'bf16')
-     MODEL   Path to a .py with MyHyperModel PL model class definition
-     CKPT    A PL checkpoint file with weights.
-     INPUT   The input data must be one of the following:
-             * directory
-               Use SonusAI mixture database directory, generate feature and truth data if not found.
-               Run prediction on the feature. The MIXID is required (or default which is *)
-
-             * Single WAV file or glob of WAV files
-               Using the given model, generate feature data and run prediction. A model file must be
-               provided. The MIXID is ignored.
-
- Outputs the following to tpredict-<TIMESTAMP> directory:
-     <id>.h5
-         dataset:    predict
-     torch_predict.log
-
- """
- from os import makedirs
- from os.path import basename
- from os.path import isdir
- from os.path import join
- from os.path import normpath
- from os.path import splitext
- from typing import Any
-
- import h5py
- import torch
- from docopt import docopt
- from lightning.pytorch import Trainer
- from lightning.pytorch.callbacks import BasePredictionWriter
- from pyaaware import FeatureGenerator
- from pyaaware import TorchInverseTransform
- from torchinfo import summary
-
- import sonusai
- from sonusai import create_file_handler
- from sonusai import initial_log_messages
- from sonusai import logger
- from sonusai import update_console_handler
- from sonusai.data_generator import TorchFromMixtureDatabase
- from sonusai.mixture import Feature
- from sonusai.mixture import MixtureDatabase
- from sonusai.mixture import get_audio_from_feature
- from sonusai.mixture import get_feature_from_audio
- from sonusai.mixture import read_audio
- from sonusai.utils import create_ts_name
- from sonusai.utils import import_keras_model
- from sonusai.utils import trim_docstring
- from sonusai.utils import write_wav
-
-
- class CustomWriter(BasePredictionWriter):
-     def __init__(self, output_dir, write_interval):
-         super().__init__(write_interval)
-         self.output_dir = output_dir
-
-     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
-         # this will create N (num processes) files in `output_dir` each containing
-         # the predictions of its respective rank
-         # torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
-
-         # optionally, you can also save `batch_indices` to get the information about the data index
-         # from your prediction data
-         num_dev = len(batch_indices)
-         logger.debug(f'Num dev: {num_dev}, prediction writer global rank: {trainer.global_rank}')
-         len_pred = len(predictions)  # for debug, should be num_dev
-         logger.debug(f'len predictions: {len_pred}, len batch_indices0 {len(batch_indices[0])}')
-         logger.debug(f'Prediction writer batch indices: {batch_indices}')
-
-         logger.info(f'Predictions returned: {len(predictions)}, writing to .h5 files ...')
-         for ndi in range(num_dev):  # iterate over list devices (num of batch groups)
-             num_batches = len(batch_indices[ndi])  # num batches in dev
-             for bi in range(num_batches):  # iterate over list of batches per dev
-                 bsz = len(batch_indices[ndi][bi])  # batch size
-                 for di in range(bsz):
-                     gid = batch_indices[0][bi][di]
-                     # gid = (bgi+1)*bi + bi
-                     # gid = bgi + bi
-                     logger.debug(f'{ndi}, {bi}, {di}: global id: {gid}')
-                     output_name = join(self.output_dir, trainer.predict_dataloaders.dataset.mixdb.mixtures[gid].name)
-                     # output_name = join(output_dir, mixdb.mixtures[i].name)
-                     pdat = predictions[bi][di, None].cpu().numpy()
-                     logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-                     with h5py.File(output_name, 'a') as f:
-                         if 'predict' in f:
-                             del f['predict']
-                         f.create_dataset('predict', data=pdat)
-
-         # output_name = join(self.output_dir,trainer.predict_dataloaders.dataset.mixdb.mixtures[0].name)
-         # logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-         # torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
-
-
- def power_compress(x):
-     real = x[..., 0]
-     imag = x[..., 1]
-     spec = torch.complex(real, imag)
-     mag = torch.abs(spec)
-     phase = torch.angle(spec)
-     mag = mag ** 0.3
-     real_compress = mag * torch.cos(phase)
-     imag_compress = mag * torch.sin(phase)
-     return torch.stack([real_compress, imag_compress], 1)
-
-
- def power_uncompress(real, imag):
-     spec = torch.complex(real, imag)
-     mag = torch.abs(spec)
-     phase = torch.angle(spec)
-     mag = mag ** (1. / 0.3)
-     real_compress = mag * torch.cos(phase)
-     imag_compress = mag * torch.sin(phase)
-     return torch.stack([real_compress, imag_compress], -1)
-
-
- def main() -> None:
-     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
-
-     verbose = args['--verbose']
-     mixids = args['--mixid']
-     accel = args['--accelerator']
-     prec = args['--precision']
-     dlcpu = int(args['--dataloader-cpus'])
-     modelpath = args['--model']
-     ckpt_name = args['--checkpoint']
-     batch_size = args['--batch']
-     timesteps = args['--tsteps']
-     reset = args['--reset']
-     wavdbg = args['--wavdbg']  # write .wav if true
-     input_name = args['INPUT']
-
-     if batch_size is not None:
-         batch_size = int(batch_size)
-         if batch_size != 1:
-             batch_size = 1
-             logger.info(f'For now prediction only supports batch_size = 1, forcing it to 1 now')
-
-     if timesteps is not None:
-         timesteps = int(timesteps)
-
-     if len(input_name) == 1 and isdir(input_name[0]):
-         in_basename = basename(normpath(input_name[0]))
-     else:
-         in_basename = ''
-
-     output_dir = create_ts_name('tpredict-' + in_basename)
-     makedirs(output_dir, exist_ok=True)
-
-     # Setup logging file
-     logger.info(f'Created output subdirectory {output_dir}')
-     create_file_handler(join(output_dir, 'torchl_predict.log'))
-     update_console_handler(verbose)
-     initial_log_messages('torch_predict')
-     logger.info(f'torch {torch.__version__}')
-
-     # Load checkpoint first to get hparams if available
-     try:
-         checkpoint = torch.load(ckpt_name, map_location=lambda storage, loc: storage)
-     except Exception as e:
-         logger.exception(f'Error: could not load checkpoint from {ckpt_name}: {e}')
-         raise SystemExit(1)
-
-     # Import model definition file
-     model_base = basename(modelpath)
-     model_root = splitext(model_base)[0]
-     logger.info(f'Importing {modelpath}')
-     litemodule = import_keras_model(modelpath)
-
-     if 'hyper_parameters' in checkpoint:
-         hparams = checkpoint['hyper_parameters']
-         logger.info(f'Found checkpoint file with hyper-params named {checkpoint["hparams_name"]} '
-                     f'with {len(hparams)} total hparams.')
-         if batch_size is not None and hparams['batch_size'] != batch_size:
-             if batch_size != 1:
-                 batch_size = 1
-                 logger.info(f'For now prediction only supports batch_size = 1, forcing it to 1 now')
-             logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
-             hparams["batch_size"] = batch_size
-
-         if timesteps is not None:
-             if hparams['timesteps'] == 0 and timesteps != 0:
-                 logger.warning(f'Model does not contain timesteps; ignoring override.')
-                 timesteps = 0
-
-             if hparams['timesteps'] != 0 and timesteps == 0:
-                 logger.warning(f'Model contains timesteps; ignoring override, using model default.')
-                 timesteps = hparams['timesteps']
-
-             if hparams['timesteps'] != timesteps:
-                 logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
-                 hparams['timesteps'] = timesteps
-
-         logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
-         # hparams['cl_per_wght'] = 0.0
-         # hparams['feature'] = 'hum00ns1'
-         try:
-             model = litemodule.MyHyperModel(**hparams)  # use hparams
-             # litemodule.MyHyperModel.load_from_checkpoint(ckpt_name, **hparams)
-         except Exception as e:
-             logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
-             raise SystemExit(1)
-     else:
-         logger.info(f'Warning: found checkpoint with no hyper-parameters, building model with defaults')
-         try:
-             tmp = litemodule.MyHyperModel()  # use default hparams
-         except Exception as e:
-             logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
-             raise SystemExit(1)
-
-         if batch_size is not None:
-             if tmp.batch_size != batch_size:
-                 logger.info(f'Overriding batch_size: default = {tmp.batch_size}; specified = {batch_size}.')
-         else:
-             batch_size = tmp.batch_size  # inherit
-
-         if timesteps is not None:
-             if tmp.timesteps == 0 and timesteps != 0:
-                 logger.warning(f'Model does not contain timesteps; ignoring override.')
-                 timesteps = 0
-
-             if tmp.timesteps != 0 and timesteps == 0:
-                 logger.warning(f'Model contains timesteps; ignoring override.')
-                 timesteps = tmp.timesteps
-
-             if tmp.timesteps != timesteps:
-                 logger.info(f'Overriding timesteps: default = {tmp.timesteps}; specified = {timesteps}.')
-         else:
-             timesteps = tmp.timesteps
-
-         logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
-         model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)
-
-     logger.info('')
-     logger.info(summary(model))
-     logger.info('')
-     logger.info(f'feature {model.hparams.feature}')
-     logger.info(f'num_classes {model.num_classes}')
-     logger.info(f'batch_size {model.hparams.batch_size}')
-     logger.info(f'timesteps {model.hparams.timesteps}')
-     logger.info(f'flatten {model.flatten}')
-     logger.info(f'add1ch {model.add1ch}')
-     logger.info(f'truth_mutex {model.truth_mutex}')
-     logger.info(f'input_shape {model.input_shape}')
-     logger.info('')
-     logger.info(f'Loading weights from {ckpt_name}')
-     # model = model.load_from_checkpoint(ckpt_name)  # weights only, needs investigation
-     model.load_state_dict(checkpoint["state_dict"])
-     model.eval()
-
-     logger.info('')
-     # Load mixture database and setup dataloader
-     if len(input_name) == 1 and isdir(input_name[0]):  # Single path to mixdb subdir
-         input_name = input_name[0]
-         logger.info(f'Loading mixture database from {input_name}')
-         mixdb = MixtureDatabase(input_name)
-         logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
-
-         if mixdb.feature != model.hparams.feature:
-             logger.warning(f'Feature in mixture database {mixdb.feature} does not match feature in model')
-             # raise SystemExit(1)
-
-         # TBD check num_classes ??
-
-         p_mixids = mixdb.mixids_to_list(mixids)
-         sampler = None
-         p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
-                                              mixids=p_mixids,
-                                              batch_size=model.hparams.batch_size,
-                                              cut_len=0,
-                                              flatten=model.flatten,
-                                              add1ch=model.add1ch,
-                                              random_cut=False,
-                                              sampler=sampler,
-                                              drop_last=False,
-                                              num_workers=dlcpu)
-
-         # Info needed to set up inverse transform
-         half = model.num_classes // 2
-         fg = FeatureGenerator(feature_mode=model.hparams.feature,
-                               num_classes=model.num_classes,
-                               truth_mutex=model.truth_mutex)
-         itf = TorchInverseTransform(N=fg.itransform_N,
-                                     R=fg.itransform_R,
-                                     bin_start=fg.bin_start,
-                                     bin_end=fg.bin_end,
-                                     ttype=fg.itransform_ttype)
-
-         enable_truth_wav = False
-         enable_mix_wav = False
-         if wavdbg:
-             if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
-                 enable_mix_wav = True
-                 enable_truth_wav = True
-             elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
-                 enable_truth_wav = True
-
-         if reset:
-             logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
-             for idx, val in enumerate(p_datagen):
-                 # truth = val[1]
-                 feature = val[0]
-                 with torch.no_grad():
-                     ypred = model(feature)
-                 output_name = join(output_dir, mixdb.mixtures[idx].name)
-                 pdat = ypred.detach().numpy()
-                 if timesteps > 0:
-                     logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
-                 logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-                 with h5py.File(output_name, 'a') as f:
-                     if 'predict' in f:
-                         del f['predict']
-                     f.create_dataset('predict', data=pdat)
-
-                 if wavdbg:
-                     owav_base = splitext(output_name)[0]
-                     tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
-                     itf.reset()
-                     predwav, _ = itf.execute_all(tmp)
-                     # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
-                     write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
-                     if enable_truth_wav:
-                         # Note this support truth type target_f and target_mixture_f
-                         tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
-                         itf.reset()
-                         truthwav, _ = itf.execute_all(tmp)
-                         write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
-
-                     if enable_mix_wav:
-                         tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
-                         itf.reset()
-                         mixwav, _ = itf.execute_all(tmp.detach())
-                         write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
-
-         else:
-             logger.info(f'Running {mixdb.num_mixtures} mixtures with model builtin prediction loop ...')
-             pred_writer = CustomWriter(output_dir=output_dir, write_interval="epoch")
-             trainer = Trainer(default_root_dir=output_dir,
-                               callbacks=[pred_writer],
-                               precision=prec,
-                               devices='auto',
-                               accelerator=accel)  # prints avail GPU, TPU, IPU, HPU and selected device
-             # trainer = Trainer(default_root_dir=output_dir,
-             #                   devices='auto',
-             #                   accelerator='auto')  # prints avail GPU, TPU, IPU, HPU and selected device
-             # logger.info(f'Strategy: {trainer.strategy.strategy_name}')  # doesn't work for ddp strategy
-             logger.info(f'Accelerator stats: {trainer.accelerator.get_device_stats(device=None)}')
-             logger.info(f'World size: {trainer.world_size}')
-             logger.info(f'Nodes: {trainer.num_nodes}')
-             logger.info(f'Devices: {trainer.accelerator.auto_device_count()}')
-
-             # Use builtin lightning prediction loop, returns a list
-             # predictions = trainer.predict(model, p_datagen)  # standard method, but no support distributed
-             with torch.no_grad():
-                 trainer.predict(model, p_datagen)
-             # predictions = model.predict_outputs
-             # pred_batch_idx = model.predict_batch_idx
-             # if trainer.world_size > 1:
-             #     ddp_max_mem = torch.cuda.max_memory_allocated(trainer.local_rank) / 1000
-             #     logger.info(f"GPU {trainer.local_rank} max memory using DDP: {ddp_max_mem:.2f} MB")
-             #     if not trainer.is_global_zero:
-             #         return
-             # logger.debug(f'type predictions: {type(predictions)}, type batch_idx: {type(pred_batch_idx)}')
-             # logger.debug(f'# predictions: {len(predictions)}, # batch_idx: {len(pred_batch_idx)}')
-             # logger.debug(f'{pred_batch_idx}')
-             # # # all_predictions = torch.cat(predictions)  # predictions = torch.cat(predictions).cpu()
-             # # if trainer.world_size > 1:
-             # #     # print(f'Predictions returned: {len(all_predictions)}')
-             # #     ddp_max_mem = torch.cuda.max_memory_allocated(trainer.local_rank) / 1000
-             # #     logger.info(f"GPU {trainer.local_rank} max memory using DDP: {ddp_max_mem:.2f} MB")
-             # #     gathered = [None] * torch.distributed.get_world_size()
-             # #     torch.distributed.all_gather_object(gathered, predictions)
-             # #     torch.distributed.all_gather_object(gathered, pred_batch_idx)
-             # #     torch.distributed.barrier()
-             # #     if not trainer.is_global_zero:
-             # #         return
-             # #     predictions = sum(gathered, [])
-             # #     if trainer.global_rank == 0:
-             # #         logger.info(f"All predictions gathered: {len(predictions)}")
-             #
-             # logger.info(f'Predictions returned: {len(predictions)}, writing to .h5 files ...')
-             # # for idx, mixid in enumerate(p_mixids):
-             # for i in pred_batch_idx:  # note assumes batch 0:num_mix matches 0:num_mix in mixdb.mixtures
-             #     # print(f'{idx}, {mixid}')
-             #     output_name = join(output_dir, mixdb.mixtures[i].name)
-             #     pdat = predictions[i].cpu().numpy()
-             #     logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-             #     with h5py.File(output_name, 'a') as f:
-             #         if 'predict' in f:
-             #             del f['predict']
-             #         f.create_dataset('predict', data=pdat)
-             #
-             #     if wavdbg:
-             #         owav_base = splitext(output_name)[0]
-             #         tmp = torch.complex(predictions[idx][..., :half], predictions[idx][..., half:]).permute(2, 1, 0)
-             #         predwav, _ = itf.execute_all(tmp.squeeze().detach().numpy())
-             #         write_wav(owav_base + ".wav", predwav.detach().numpy(), 16000)
-
-         logger.info(f'Saved results to {output_dir}')
-         return
-
-     # if reset:
-     #     # reset mode cycles through each file one at a time
-     #     for mixid in mixids:
-     #         feature, _ = mixdb.mixture_ft(mixid)
-     #         if feature.shape[0] > 2500:
-     #             print(f'Trimming input frames from {feature.shape[0]} to {2500},')
-     #             feature = feature[0:2500, ::]
-     #         half = feature.shape[-1] // 2
-     #         noisy_spec_cmplx = torch.complex(torch.tensor(feature[..., :half]),
-     #                                          torch.tensor(feature[..., half:])).to(device)
-     #         del feature
-     #
-     #         predict = _pad_and_predict(built_model=model, feature=noisy_spec_cmplx)
-     #         del noisy_spec_cmplx
-     #
-     #         audio_est = torch_istft_olsa_hanns(predict, mixdb.it_config.N, mixdb.it_config.R).cpu()
-     #         del predict
-     #         output_name = join(output_dir, splitext(mixdb.mixtures[mixid].name)[0] + '.wav')
-     #         print(f'Saving prediction to {output_name}')
-     #         write_wav(name=output_name, audio=float_to_int16(audio_est.detach().numpy()).transpose())
-     #
-     #         torch.cuda.empty_cache()
-     #
-     #         # TBD .h5 predict file optional output file
-     #         # output_name = join(output_dir, mixdb.mixtures[mixid].name)
-     #         # with h5py.File(output_name, 'a') as f:
-     #         #     if 'predict' in f:
-     #         #         del f['predict']
-     #         #     f.create_dataset(name='predict', data=predict)
-     #
-     # else:
-     #     # Run all data at once using a data generator
-     #     feature = KerasFromH5(mixdb=mixdb,
-     #                           mixids=mixids,
-     #                           batch_size=hypermodel.batch_size,
-     #                           timesteps=hypermodel.timesteps,
-     #                           flatten=hypermodel.flatten,
-     #                           add1ch=hypermodel.add1ch)
-     #
-     #     predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
-     #     predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
-     #
-     #     # Write data to separate files
-     #     for idx, mixid in enumerate(mixids):
-     #         output_name = join(output_dir, mixdb.mixtures[mixid].name)
-     #         with h5py.File(output_name, 'a') as f:
-     #             if 'predict' in f:
-     #                 del f['predict']
-     #             f.create_dataset('predict', data=predict[feature.file_indices[idx]])
-     #
-     #     logger.info(f'Saved results to {output_dir}')
-     #     return
-
-     logger.info(f'Run prediction on {len(input_name):,} audio files')
-     for file in input_name:
-         # Convert audio to feature data
-         audio_in = read_audio(file)
-         feature = get_feature_from_audio(audio=audio_in, feature_mode=model.hparams.feature)
-
-         with torch.no_grad():
-             predict = model(torch.tensor(feature))
-
-         audio_out = get_audio_from_feature(feature=predict.numpy(), feature_mode=model.hparams.feature)
-
-         output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
-         with h5py.File(output_name, 'a') as f:
-             if 'audio_in' in f:
-                 del f['audio_in']
-             f.create_dataset(name='audio_in', data=audio_in)
-
-             if 'feature' in f:
-                 del f['feature']
-             f.create_dataset(name='feature', data=feature)
-
-             if 'predict' in f:
-                 del f['predict']
-             f.create_dataset(name='predict', data=predict)
-
-             if 'audio_out' in f:
-                 del f['audio_out']
-             f.create_dataset(name='audio_out', data=audio_out)
-
-         output_name = join(output_dir, splitext(basename(file))[0] + '_predict.wav')
-         write_wav(output_name, audio_out, 16000)
-
-     logger.info(f'Saved results to {output_dir}')
-     del model
-
-
- def _pad_and_predict(built_model: Any, feature: Feature) -> torch.Tensor:
-     """
-     Run prediction on feature [frames,1,bins*2] (stacked complex numpy array, stride/tsteps=1)
-     Returns predict output [batch,frames,bins] in complex torch.tensor
-     """
-     noisy_spec = power_compress(torch.view_as_real(torch.from_numpy(feature).permute(1, 0, 2)))
-     # print(f'noisy_spec type {type(noisy_spec_cmplx)}')
-     # print(f'noisy_spec dtype {noisy_spec_cmplx.dtype}')
-     # print(f'noisy_spec size {noisy_spec_cmplx.shape}')
-     with torch.no_grad():
-         est_real, est_imag = built_model(noisy_spec)  # expects in size [batch, 2, tsteps, bins]
-         est_real, est_imag = est_real.permute(0, 1, 3, 2), est_imag.permute(0, 1, 3, 2)
-         est_spec_uncompress = torch.view_as_complex(power_uncompress(est_real, est_imag).squeeze(1))
-         # inv tf want [ch,frames,bins] complex (synonymous with [batch,tsteps,bins]), keep as torch.tensor
-         predict = est_spec_uncompress.permute(0, 2, 1)  # .detach().numpy()
-
-     return predict
-
-
- if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         exit()
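
Note on the removed CustomWriter: it subclasses Lightning's BasePredictionWriter callback so that, in distributed prediction, each rank persists its own results at epoch end. A minimal standalone sketch of that callback contract (the per-rank torch.save output file is hypothetical, mirroring the commented-out line in the deleted code):

    import os
    import torch
    from lightning.pytorch.callbacks import BasePredictionWriter

    class PredictionWriter(BasePredictionWriter):
        def __init__(self, output_dir: str):
            super().__init__(write_interval="epoch")  # write once, at end of predict epoch
            self.output_dir = output_dir

        def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
            # Each rank writes its own file; batch_indices maps each prediction
            # back to its dataset index so results can be reassembled afterwards.
            os.makedirs(self.output_dir, exist_ok=True)
            torch.save(
                {"predictions": predictions, "batch_indices": batch_indices},
                os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"),
            )

    # usage sketch: Trainer(callbacks=[PredictionWriter("out")]).predict(model, loader)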
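The removed power_compress/power_uncompress pair applies power-law magnitude compression (mag ** 0.3) to a complex spectrogram while preserving phase, a trick used by CMGAN-style speech-enhancement models to reduce dynamic range. A self-contained round-trip sketch (shapes hypothetical) showing that uncompress inverts compress:

    import torch

    def power_compress(x: torch.Tensor) -> torch.Tensor:
        # x: [..., 2] real/imag pairs; compress magnitude, keep phase
        spec = torch.complex(x[..., 0], x[..., 1])
        mag, phase = torch.abs(spec) ** 0.3, torch.angle(spec)
        return torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], 1)

    def power_uncompress(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
        spec = torch.complex(real, imag)
        mag, phase = torch.abs(spec) ** (1.0 / 0.3), torch.angle(spec)
        return torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], -1)

    x = torch.randn(1, 10, 257, 2)              # hypothetical [batch, frames, bins, real/imag]
    c = power_compress(x)                       # -> [batch, 2, frames, bins]
    y = power_uncompress(c[:, 0], c[:, 1])      # -> [batch, frames, bins, real/imag]
    print(torch.allclose(x, y, atol=1e-3))      # True: round trip recovers the spectrogram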
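The --wavdbg branch rebuilds complex spectra from features whose last dimension stacks real and imaginary halves (and, when the truth function is 'target_mixture_f', target and mixture halves side by side). A small sketch of that slicing, with hypothetical sizes, assuming the [target_real | target_imag | mix_real | mix_imag] layout implied by the indexing above:

    import torch

    half = 128                                  # hypothetical bin count per half
    feature = torch.randn(200, 1, 4 * half)     # hypothetical [frames, ch, stacked bins]
    target = torch.complex(feature[..., :half], feature[..., half:2 * half])
    mixture = torch.complex(feature[..., 2 * half:3 * half], feature[..., 3 * half:])
    print(target.shape, mixture.dtype)          # torch.Size([200, 1, 128]) torch.complex64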