sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +36 -4
- sonusai/audiofe.py +111 -106
- sonusai/calc_metric_spenh.py +38 -22
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +15 -7
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +1 -0
- sonusai/mixture/config.py +1 -2
- sonusai/mkmanifest.py +43 -8
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict.py +16 -6
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/__init__.py +4 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -542
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/torchl_predict.py
DELETED
@@ -1,542 +0,0 @@
|
|
1
|
-
"""sonusai torchl_predict
|
2
|
-
|
3
|
-
usage: torchl_predict [-hvrw] [-i MIXID] [-a ACCEL] [-p PREC] [-d DLCPU] [-m MODEL]
|
4
|
-
(-k CKPT) [-b BATCH] [-t TSTEPS] INPUT ...
|
5
|
-
|
6
|
-
options:
|
7
|
-
-h, --help
|
8
|
-
-v, --verbose Be verbose.
|
9
|
-
-i MIXID, --mixid MIXID Mixture ID(s) to use if input is a mixture database. [default: *].
|
10
|
-
-a ACCEL, --accelerator ACCEL Accelerator to use in PL trainer in non-reset mode [default: auto]
|
11
|
-
-p PREC, --precision PREC Precision to use in PL trainer in non-reset mode. [default: 32]
|
12
|
-
-d DLCPU, --dataloader-cpus Number of workers/cpus for dataloader. [default: 0]
|
13
|
-
-m MODEL, --model MODEL PL model .py file path.
|
14
|
-
-k CKPT, --checkpoint CKPT PL checkpoint file with weights.
|
15
|
-
-b BATCH, --batch BATCH Batch size (deprecated and forced to 1). [default: 1]
|
16
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps. If 0, dim is not included/expected in model. [default: 0]
|
17
|
-
-r, --reset Reset model between each file.
|
18
|
-
-w, --wavdbg Write debug .wav files of feature input, truth, and predict. [default: False]
|
19
|
-
|
20
|
-
Run PL (Pytorch Lightning) prediction with model and checkpoint input using input data from a
|
21
|
-
SonusAI mixture database.
|
22
|
-
The PL model is imported from MODEL .py file and weights loaded from checkpoint file CKPT.
|
23
|
-
|
24
|
-
Inputs:
|
25
|
-
ACCEL Accelerator used for PL prediction. As of PL v2.0.8: auto, cpu, cuda, hpu, ipu, mps, tpu
|
26
|
-
PREC Precision used in PL prediction. PL trainer will convert model+weights to specified prec.
|
27
|
-
As of PL v2.0.8:
|
28
|
-
('16-mixed', 'bf16-mixed', '32-true', '64-true', 64, 32, 16, '64', '32', '16', 'bf16')
|
29
|
-
MODEL Path to a .py with MyHyperModel PL model class definition
|
30
|
-
CKPT A PL checkpoint file with weights.
|
31
|
-
INPUT The input data must be one of the following:
|
32
|
-
* directory
|
33
|
-
Use SonusAI mixture database directory, generate feature and truth data if not found.
|
34
|
-
Run prediction on the feature. The MIXID is required (or default which is *)
|
35
|
-
|
36
|
-
* Single WAV file or glob of WAV files
|
37
|
-
Using the given model, generate feature data and run prediction. A model file must be
|
38
|
-
provided. The MIXID is ignored.
|
39
|
-
|
40
|
-
Outputs the following to tpredict-<TIMESTAMP> directory:
|
41
|
-
<id>.h5
|
42
|
-
dataset: predict
|
43
|
-
torchl_predict.log
|
44
|
-
|
45
|
-
"""
|
46
|
-
from os import makedirs
|
47
|
-
from os.path import basename
|
48
|
-
from os.path import isdir
|
49
|
-
from os.path import join
|
50
|
-
from os.path import normpath
|
51
|
-
from os.path import splitext
|
52
|
-
from typing import Any
|
53
|
-
|
54
|
-
import h5py
|
55
|
-
import torch
|
56
|
-
from docopt import docopt
|
57
|
-
from lightning.pytorch import Trainer
|
58
|
-
from lightning.pytorch.callbacks import BasePredictionWriter
|
59
|
-
from pyaaware import FeatureGenerator
|
60
|
-
from pyaaware import TorchInverseTransform
|
61
|
-
from torchinfo import summary
|
62
|
-
|
63
|
-
import sonusai
|
64
|
-
from sonusai import create_file_handler
|
65
|
-
from sonusai import initial_log_messages
|
66
|
-
from sonusai import logger
|
67
|
-
from sonusai import update_console_handler
|
68
|
-
from sonusai.data_generator import TorchFromMixtureDatabase
|
69
|
-
from sonusai.mixture import Feature
|
70
|
-
from sonusai.mixture import MixtureDatabase
|
71
|
-
from sonusai.mixture import get_audio_from_feature
|
72
|
-
from sonusai.mixture import get_feature_from_audio
|
73
|
-
from sonusai.mixture import read_audio
|
74
|
-
from sonusai.utils import create_ts_name
|
75
|
-
from sonusai.utils import import_keras_model
|
76
|
-
from sonusai.utils import trim_docstring
|
77
|
-
from sonusai.utils import write_wav
|
78
|
-
|
79
|
-
|
80
|
-
class CustomWriter(BasePredictionWriter):
    """Prediction-writer callback that stores predictions in per-mixture HDF5 files.

    Each prediction is written to a 'predict' dataset in a file inside
    `output_dir`; the file name is the `name` of the mixture looked up through
    the trainer's predict dataloader.
    """

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Write every prediction of the finished epoch to its mixture's HDF5 file.

        `batch_indices` is walked as a three-level nesting (group, batch, item)
        and `predictions` is indexed as predictions[batch][item].
        """
        num_dev = len(batch_indices)
        logger.debug(f'Num dev: {num_dev}, prediction writer global rank: {trainer.global_rank}')
        len_pred = len(predictions)  # debugging aid only
        logger.debug(f'len predictions: {len_pred}, len batch_indices0 {len(batch_indices[0])}')
        logger.debug(f'Prediction writer batch indices: {batch_indices}')

        logger.info(f'Predictions returned: {len(predictions)}, writing to .h5 files ...')
        for ndi, group in enumerate(batch_indices):
            for bi, batch in enumerate(group):
                for di in range(len(batch)):
                    # NOTE(review): the id is always taken from group 0 and predictions are not
                    # indexed by group; presumably only a single group is ever present - confirm.
                    gid = batch_indices[0][bi][di]
                    logger.debug(f'{ndi}, {bi}, {di}: global id: {gid}')

                    # NOTE(review): assumes the dataset index equals the mixture index - confirm.
                    mixture = trainer.predict_dataloaders.dataset.mixdb.mixtures[gid]
                    output_name = join(self.output_dir, mixture.name)

                    # Indexing with None keeps a leading axis of length 1 on the stored data.
                    pdat = predictions[bi][di, None].cpu().numpy()
                    logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')

                    with h5py.File(output_name, 'a') as f:
                        # Replace any 'predict' dataset left over from an earlier write.
                        if 'predict' in f:
                            del f['predict']
                        f.create_dataset('predict', data=pdat)
|
120
|
-
|
121
|
-
|
122
|
-
def power_compress(x, exponent=0.3):
    """Apply power-law magnitude compression to a stacked real/imaginary spectrum.

    The last axis of `x` holds the real part at index 0 and the imaginary part
    at index 1. The magnitude is raised to `exponent` while the phase is kept.

    :param x: Tensor whose last axis is (real, imag)
    :param exponent: Power applied to the magnitude (default 0.3, the original fixed value)
    :return: Tensor with the compressed real and imaginary parts stacked on a new axis 1
    """
    spec = torch.complex(x[..., 0], x[..., 1])
    mag = torch.abs(spec) ** exponent
    phase = torch.angle(spec)
    real_compress = mag * torch.cos(phase)
    imag_compress = mag * torch.sin(phase)
    # The new axis is inserted at position 1, not at the end.
    return torch.stack([real_compress, imag_compress], 1)
|
132
|
-
|
133
|
-
|
134
|
-
def power_uncompress(real, imag, exponent=0.3):
    """Invert power-law magnitude compression.

    The magnitude of the complex value `real + j*imag` is raised to
    `1 / exponent` while the phase is kept.

    :param real: Tensor holding the compressed real part
    :param imag: Tensor holding the compressed imaginary part
    :param exponent: Exponent that was used for compression (default 0.3, the original fixed value)
    :return: Tensor with the uncompressed real and imaginary parts stacked on a new last axis
    :raises ValueError: If `exponent` is zero, since the inverse power is undefined
    """
    if exponent == 0:
        raise ValueError('exponent must be non-zero')

    spec = torch.complex(real, imag)
    mag = torch.abs(spec) ** (1. / exponent)
    phase = torch.angle(spec)
    real_uncompress = mag * torch.cos(phase)
    imag_uncompress = mag * torch.sin(phase)
    return torch.stack([real_uncompress, imag_uncompress], -1)
|
142
|
-
|
143
|
-
|
144
|
-
def _replace_dataset(h5file, name, data) -> None:
    """Create dataset `name` in an open h5py file, replacing an existing one of that name."""
    if name in h5file:
        del h5file[name]
    h5file.create_dataset(name, data=data)


def _write_debug_wav(itf, spectrum, name) -> None:
    """Reset the inverse transform, convert `spectrum` to audio and write it to `name`."""
    itf.reset()
    audio, _ = itf.execute_all(spectrum)
    write_wav(name, audio.permute([1, 0]).numpy(), 16000)


def main() -> None:
    """Command-line entry point: run prediction with a PL model and checkpoint.

    Parses the module docstring with docopt, creates a time-stamped output
    directory, loads the checkpoint, builds the model, and then either

    * runs prediction over a mixture database directory (one mixture at a time
      in reset mode, or through the Lightning Trainer otherwise), or
    * runs prediction over the given audio files.

    Raises SystemExit(1) if the model file is missing, the checkpoint cannot be
    loaded, or the model cannot be built.
    """
    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    accel = args['--accelerator']
    prec = args['--precision']
    dlcpu = int(args['--dataloader-cpus'])
    modelpath = args['--model']
    ckpt_name = args['--checkpoint']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    reset = args['--reset']
    wavdbg = args['--wavdbg']  # write .wav if true
    input_name = args['INPUT']

    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size != 1:
            batch_size = 1
            logger.info('For now prediction only supports batch_size = 1, forcing it to 1 now')

    if timesteps is not None:
        timesteps = int(timesteps)

    # A single directory argument is treated as a mixture database.
    is_mixdb = len(input_name) == 1 and isdir(input_name[0])
    in_basename = basename(normpath(input_name[0])) if is_mixdb else ''

    output_dir = create_ts_name('tpredict-' + in_basename)
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    logger.info(f'Created output subdirectory {output_dir}')
    create_file_handler(join(output_dir, 'torchl_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('torch_predict')
    logger.info(f'torch {torch.__version__}')

    # The usage line marks -m as optional, but the model definition is always needed below.
    if modelpath is None:
        logger.error('Error: a model definition file is required (-m MODEL)')
        raise SystemExit(1)

    # Load checkpoint first to get hparams if available.
    # NOTE: torch.load unpickles the file; only use checkpoints from trusted sources.
    try:
        checkpoint = torch.load(ckpt_name, map_location=lambda storage, loc: storage)
    except Exception as e:
        logger.exception(f'Error: could not load checkpoint from {ckpt_name}: {e}')
        raise SystemExit(1)

    # Import model definition file
    model_base = basename(modelpath)
    logger.info(f'Importing {modelpath}')
    litemodule = import_keras_model(modelpath)

    if 'hyper_parameters' in checkpoint:
        hparams = checkpoint['hyper_parameters']
        logger.info(f'Found checkpoint file with hyper-params named {checkpoint.get("hparams_name")} '
                    f'with {len(hparams)} total hparams.')
        if batch_size is not None and hparams['batch_size'] != batch_size:
            logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
            hparams['batch_size'] = batch_size

        if timesteps is not None:
            if hparams['timesteps'] == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if hparams['timesteps'] != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override, using model default.')
                timesteps = hparams['timesteps']

            if hparams['timesteps'] != timesteps:
                logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
                hparams['timesteps'] = timesteps
        else:
            timesteps = hparams['timesteps']  # inherit, as the no-hparams branch does

        logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
        try:
            model = litemodule.MyHyperModel(**hparams)
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1)
    else:
        logger.info('Warning: found checkpoint with no hyper-parameters, building model with defaults')
        try:
            default_model = litemodule.MyHyperModel()  # built only to read default hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1)

        if batch_size is not None:
            if default_model.batch_size != batch_size:
                logger.info(f'Overriding batch_size: default = {default_model.batch_size}; '
                            f'specified = {batch_size}.')
        else:
            batch_size = default_model.batch_size  # inherit

        if timesteps is not None:
            if default_model.timesteps == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if default_model.timesteps != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override.')
                timesteps = default_model.timesteps

            if default_model.timesteps != timesteps:
                logger.info(f'Overriding timesteps: default = {default_model.timesteps}; '
                            f'specified = {timesteps}.')
        else:
            timesteps = default_model.timesteps

        logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
        model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)

    logger.info('')
    logger.info(summary(model))
    logger.info('')
    logger.info(f'feature       {model.hparams.feature}')
    logger.info(f'num_classes   {model.num_classes}')
    logger.info(f'batch_size    {model.hparams.batch_size}')
    logger.info(f'timesteps     {model.hparams.timesteps}')
    logger.info(f'flatten       {model.flatten}')
    logger.info(f'add1ch        {model.add1ch}')
    logger.info(f'truth_mutex   {model.truth_mutex}')
    logger.info(f'input_shape   {model.input_shape}')
    logger.info('')
    logger.info(f'Loading weights from {ckpt_name}')
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    logger.info('')
    # Load mixture database and setup dataloader
    if is_mixdb:
        mixdb_location = input_name[0]
        logger.info(f'Loading mixture database from {mixdb_location}')
        mixdb = MixtureDatabase(mixdb_location)
        logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')

        if mixdb.feature != model.hparams.feature:
            # Deliberately a warning only; prediction continues with the model's feature.
            logger.warning(f'Feature in mixture database {mixdb.feature} does not match feature in model')

        p_mixids = mixdb.mixids_to_list(mixids)
        p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
                                             mixids=p_mixids,
                                             batch_size=model.hparams.batch_size,
                                             cut_len=0,
                                             flatten=model.flatten,
                                             add1ch=model.add1ch,
                                             random_cut=False,
                                             sampler=None,
                                             drop_last=False,
                                             num_workers=dlcpu)

        # Info needed to set up inverse transform
        half = model.num_classes // 2
        fg = FeatureGenerator(feature_mode=model.hparams.feature,
                              num_classes=model.num_classes,
                              truth_mutex=model.truth_mutex)
        itf = TorchInverseTransform(N=fg.itransform_N,
                                    R=fg.itransform_R,
                                    bin_start=fg.bin_start,
                                    bin_end=fg.bin_end,
                                    ttype=fg.itransform_ttype)

        enable_truth_wav = False
        enable_mix_wav = False
        if wavdbg:
            truth_function = mixdb.target_files[0].truth_settings[0].function
            if truth_function == 'target_mixture_f':
                enable_mix_wav = True
                enable_truth_wav = True
            elif truth_function == 'target_f':
                enable_truth_wav = True

        if reset:
            logger.info(f'Running {len(p_mixids)} mixtures individually with model reset ...')
            for idx, val in enumerate(p_datagen):
                feature = val[0]
                with torch.no_grad():
                    ypred = model(feature)

                # Name the output after the selected mixture ID, not the loop position, so a
                # subset given with --mixid maps to the right files.
                # NOTE(review): assumes p_datagen yields mixtures in p_mixids order - confirm.
                output_name = join(output_dir, mixdb.mixtures[p_mixids[idx]].name)
                pdat = ypred.detach().numpy()
                if timesteps > 0:
                    logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
                logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
                with h5py.File(output_name, 'a') as f:
                    _replace_dataset(f, 'predict', pdat)

                if wavdbg:
                    owav_base = splitext(output_name)[0]
                    pred_spec = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
                    _write_debug_wav(itf, pred_spec, owav_base + '.wav')

                    if enable_truth_wav:
                        # Supports truth types target_f and target_mixture_f.
                        # NOTE(review): this slices val[0] (the feature), not val[1]; confirm
                        # whether the truth element was intended.
                        truth_spec = torch.complex(val[0][..., :half],
                                                   val[0][..., half:2 * half]).permute(2, 0, 1).detach()
                        _write_debug_wav(itf, truth_spec, owav_base + '_truth.wav')

                    if enable_mix_wav:
                        mix_spec = torch.complex(val[0][..., 2 * half:3 * half],
                                                 val[0][..., 3 * half:]).permute(2, 0, 1).detach()
                        _write_debug_wav(itf, mix_spec, owav_base + '_mix.wav')
        else:
            logger.info(f'Running {len(p_mixids)} mixtures with model builtin prediction loop ...')
            pred_writer = CustomWriter(output_dir=output_dir, write_interval="epoch")
            trainer = Trainer(default_root_dir=output_dir,
                              callbacks=[pred_writer],
                              precision=prec,
                              devices='auto',
                              accelerator=accel)
            logger.info(f'Accelerator stats: {trainer.accelerator.get_device_stats(device=None)}')
            logger.info(f'World size: {trainer.world_size}')
            logger.info(f'Nodes: {trainer.num_nodes}')
            logger.info(f'Devices: {trainer.accelerator.auto_device_count()}')

            # Predictions are written by the CustomWriter callback; the return value is unused.
            with torch.no_grad():
                trainer.predict(model, p_datagen)

        logger.info(f'Saved results to {output_dir}')
        return

    logger.info(f'Run prediction on {len(input_name):,} audio files')
    for file in input_name:
        # Convert audio to feature data
        audio_in = read_audio(file)
        feature = get_feature_from_audio(audio=audio_in, feature_mode=model.hparams.feature)

        with torch.no_grad():
            predict = model(torch.tensor(feature))

        predict_data = predict.numpy()
        audio_out = get_audio_from_feature(feature=predict_data, feature_mode=model.hparams.feature)

        file_base = splitext(basename(file))[0]
        with h5py.File(join(output_dir, file_base + '.h5'), 'a') as f:
            _replace_dataset(f, 'audio_in', audio_in)
            _replace_dataset(f, 'feature', feature)
            _replace_dataset(f, 'predict', predict_data)
            _replace_dataset(f, 'audio_out', audio_out)

        write_wav(join(output_dir, file_base + '_predict.wav'), audio_out, 16000)

    logger.info(f'Saved results to {output_dir}')
|
516
|
-
|
517
|
-
|
518
|
-
def _pad_and_predict(built_model: Any, feature: Feature) -> torch.Tensor:
    """Compress a feature, run the model on it and return the uncompressed prediction.

    The feature is converted to a tensor, its first two axes are swapped, and it
    is viewed as stacked real/imaginary values before power compression. The
    model is expected to return a (real, imag) pair.

    NOTE(review): the original documentation describes the input as
    [frames,1,bins*2] and the output as [batch,frames,bins] complex - confirm
    against the caller; torch.view_as_real requires a complex input.
    """
    stacked = torch.view_as_real(torch.from_numpy(feature).permute(1, 0, 2))
    noisy_spec = power_compress(stacked)

    with torch.no_grad():
        # The model is presumably fed [batch, 2, tsteps, bins] - TODO confirm.
        est_real, est_imag = built_model(noisy_spec)
        est_real = est_real.permute(0, 1, 3, 2)
        est_imag = est_imag.permute(0, 1, 3, 2)
        uncompressed = power_uncompress(est_real, est_imag).squeeze(1)
        est_spec_uncompress = torch.view_as_complex(uncompressed)

    # Swap the last two axes for the inverse transform; the result stays a torch tensor.
    return est_spec_uncompress.permute(0, 2, 1)
|
535
|
-
|
536
|
-
|
537
|
-
if __name__ == '__main__':
    # Script entry point: run main() and exit quietly on Ctrl-C.
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # The bare exit() helper is injected by the site module and may be absent;
        # raising SystemExit gives the same zero exit status without relying on it.
        raise SystemExit(0)
|