sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,508 @@
1
+ """sonusai ovino_predict
2
+
3
+ usage: ovino_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
4
+
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -l, --list-device-details List details of all OpenVINO available devices
9
+ -i MIXID, --mixid MIXID Mixture ID(s) to process if input is a mixture database. [default: *].
10
+ --include GLOB Search only files whose base name matches GLOB. [default: *.{wav,flac}].
11
+ -w, --write-wav Calculate inverse transform of prediction and write .wav files
12
+ -r, --reset Reset model between each file.
13
+
14
+
15
+ Run prediction (inference) using an OpenVINO model on a SonusAI mixture dataset or audio files from a glob path.
16
+
17
+ Inputs:
18
+ MODEL OpenVINO model .xml file (called the ov.Model format), requires accompanying .bin file
19
+ to exist in same location with the same base filename.
20
+
21
+ DATA The input data must be one of the following:
22
+ * WAV
23
+ Using the given model, generate feature data and run prediction. A model file must be
24
+ provided. The MIXID is ignored.
25
+
26
+ * directory
27
+ Using the given SonusAI mixture database directory, generate feature and truth data if not found.
28
+ Run prediction on the mixtures selected by MIXID (see --mixid).
29
+
30
+
31
+ Note there are multiple ways to process model prediction over multiple audio data files:
32
+ 1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model is run as
33
+ a single inference call. If batch_size > 1, multiple mixtures are run in one call, with shorter mixtures
34
+ zero-padded to the size of the largest mixture.
35
+ 2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This is possible only if
36
+ timesteps=1 or there is no timestep dimension in the model (timesteps=0).
37
+ 3. (TBD: not yet documented.)
38
+
39
+ Outputs the following to ovpredict-<TIMESTAMP> directory:
40
+ <id>.h5
41
+ dataset: predict
42
+ ovino_predict.log
43
+
44
+ """
45
+
46
+ from sonusai import logger
47
+
48
def param_to_string(parameters) -> str:
    """Return an OpenVINO property value formatted as a string.

    A list or tuple is rendered as the ``str()`` of each item joined by
    ``', '``; any other value is simply passed through ``str()``.
    """
    if not isinstance(parameters, (list, tuple)):
        return str(parameters)
    return ', '.join(map(str, parameters))
54
+
55
+
56
def openvino_list_device_details(core):
    """Log every available OpenVINO device together with its supported properties.

    :param core: an OpenVINO runtime core; only ``available_devices`` and
        ``get_property(device, key)`` are used.
    """
    logger.info('Requested details of OpenVINO available devices:')
    for device in core.available_devices:
        logger.info(f'{device} :')
        logger.info('\tSUPPORTED_PROPERTIES:')
        for property_key in core.get_property(device, 'SUPPORTED_PROPERTIES'):
            # Skip the meta-property that lists the others. This must be an
            # equality test: ('SUPPORTED_PROPERTIES') is a plain string, not a
            # tuple, so the former `not in` check was a substring match.
            if property_key == 'SUPPORTED_PROPERTIES':
                continue
            try:
                property_val = core.get_property(device, property_key)
            except TypeError:
                # Report values get_property cannot return instead of aborting the listing.
                property_val = 'UNSUPPORTED TYPE'
            logger.info(f'\t\t{property_key}: {param_to_string(property_val)}')
        logger.info('')
69
+
70
+
71
+
72
+ def main() -> None:
+ # CLI entry point (usage in the module docstring): parse arguments, create the OpenVINO runtime,
+ # collect the input data (a mixture-database directory or audio files), compile MODEL, run
+ # inference and write a 'predict' dataset per mixture into the output directory.
+ # NOTE(review): unfinished -- from the per-mixture loop onward this function reads names that are
+ # never assigned in it (see the NOTE(review) comments below), so a run would raise NameError
+ # after the model has been compiled.
73
+ from docopt import docopt
74
+
75
+ import sonusai
76
+ from sonusai.utils import trim_docstring
77
+ import openvino as ov
78
+
79
+ args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
80
+
81
+ verbose = args['--verbose']
82
+ listdd = args['--list-device-details']
83
+ writewav= args['--write-wav']
84
+ mixids = args['--mixid']
85
+ reset = args['--reset']
86
+ include = args['--include']
87
+ model_name = args['MODEL']
88
+ datapaths = args['DATA']
89
+
90
+
91
+ core = ov.Core() # Create runtime
92
+ avail_devices = core.available_devices # simple report of available devices
93
+ logger.info(f'Loaded OpenVINO runtime, available devices: {avail_devices}.')
94
+ if listdd is True: # print
95
+ openvino_list_device_details(core)
96
+
97
+ from os.path import abspath, join, realpath, basename, isdir, normpath, splitext
98
+ from sonusai.utils.asr_manifest_functions import PathInfo
99
+ from sonusai.utils import braced_iglob
100
+ from sonusai.mixture import MixtureDatabase
101
+ from typing import Any
102
+
103
+ mixdb_path = None
104
+ entries = None
105
+ if len(datapaths) == 1 and isdir(datapaths[0]): # Assume it's a single path to sonusai mixdb subdir
106
+ in_basename = basename(normpath(datapaths[0]))
107
+ mixdb_path= datapaths[0]
108
+ logger.debug(f'Attempting to load mixture database from {mixdb_path}')
109
+ mixdb = MixtureDatabase(mixdb_path)
110
+ logger.debug(f'Sonusai mixture db load success: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
111
+ p_mixids = mixdb.mixids_to_list(mixids)
112
+ if len(p_mixids) != mixdb.num_mixtures:
113
+ logger.info(f'Processing a subset of {p_mixids} from available mixtures.')
114
+ else: # search all datapaths for .wav, .flac (or whatever is specified in include)
115
+ in_basename = ''
116
+ entries: list[PathInfo] = []
117
+ for p in datapaths:
118
+ location = join(realpath(abspath(p)), '**', include)
+ # NOTE(review): every DATA path is used as a directory root ('<path>/**/<include>'), so a DATA
+ # argument naming a single audio file matches nothing -- TODO confirm this is intended.
119
+ logger.debug(f'Processing {location}')
120
+ for file in braced_iglob(pathname=location, recursive=True):
121
+ name = file
122
+ entries.append(PathInfo(abs_path=file, audio_filepath=name))
123
+
124
+ logger.info(f'Reading and compiling model from {model_name}.')
125
+ compiled_model = core.compile_model(model_name, "AUTO")
126
+ logger.info(f'Compiled model using default OpenVino compile settings.')
127
+
128
+ from sonusai.utils import create_ts_name
129
+ from os import makedirs
130
+ from sonusai import create_file_handler
131
+ from sonusai import initial_log_messages
132
+ from sonusai import update_console_handler
133
+ output_dir = create_ts_name('ovpredict-' + in_basename)
134
+ makedirs(output_dir, exist_ok=True)
135
+ # Setup logging file
136
+ create_file_handler(join(output_dir, 'ovino_predict.log'))
137
+ update_console_handler(verbose)
138
+ initial_log_messages('ovino_predict')
139
+ logger.info(f'Read and compiled OpenVINO model from {model_name}.')
140
+ if len(datapaths) == 1 and isdir(datapaths[0]): # Assume it's a single path to sonusai mixdb subdir
141
+ logger.info(f'Loaded mixture database from {datapaths}')
142
+ logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
143
+ else:
144
+ logger.info(f'{len(datapaths)} data paths specified, found {len(entries)} audio files.')
+ # NOTE(review): with -l the device details were already listed right after ov.Core() was created;
+ # this second listing presumably exists so it also reaches the log file set up above.
145
+ if listdd is True: # print
146
+ openvino_list_device_details(core)
147
+
148
+ if len(compiled_model.inputs) != 1 or len(compiled_model.outputs) != 1:
149
+ logger.warning(f'Model has incorrect i/o, expected 1,1 but got {len(compiled_model.inputs)}, {len(compiled_model.outputs)}.')
150
+ logger.warning(f'Using the first input and output.')
151
+ else:
152
+ logger.info(f'Model has 1 input and 1 output as expected.')
153
+
154
+ inp0shape = compiled_model.inputs[0].partial_shape #compiled_model.inputs[0].shape # fails on dynamic
155
+ # out0shape = compiled_model.outputs[0].shape # fails on dynamic
156
+ logger.info(f'Model input0 shape: {inp0shape}')
157
+ # model_num_classes = inp0shape[-1]
158
+ # logger.info(f'Model input0 shape {inp0shape}')
159
+ # logger.info(f'Model output0 shape {out0shape}')
160
+
161
+
162
+ # Check/calculate key Sonusai hyper-params
163
+ # model_num_classes = inp0shape[-1]
164
+ # batch_size = inp0shape[0] # note batch dim can be used for frames dim if no timestep dimension
165
+ # if batch_size != 1:
166
+ # logger.warning(f'Model batch size is not 1, but {batch_size}.')
167
+ #
168
+ # if len(inp0shape) > 2:
169
+ # if len(inp0shape) == 3:
170
+ # tsteps = inp0shape[1] # note tstep dim can be used for frames dim if it exists
171
+ # else:
172
+ # logger.debug(f'Model input has more than 3 dims, assuming timestep dimension is 0 (does not exist).')
173
+ # tsteps = 0
174
+ # else:
175
+ # logger.debug(f'Model input has 2 dims, timestep dimension is 0 (does not exist).')
176
+ # tsteps = 0
177
+ #
178
+ # flattened = True # TBD get from model
179
+ # add1ch = False
+ # NOTE(review): flattened, add1ch and tsteps are only assigned in the commented-out block above,
+ # yet the per-mixture loop below reads all three.
180
+
181
+ if mixdb_path is not None: # mixdb input
+ # NOTE(review): mixdb, p_mixids and feature_mode exist only for mixture-database input; the loop
+ # below uses them without re-checking mixdb_path.
182
+ # Assume (of course) that mixdb feature, etc. is what model expects
183
+ feature_mode = mixdb.feature
184
+ # if mixdb.num_classes != model_num_classes:
185
+ # logger.error(f'Feature parameters in mixture db {mixdb.num_classes} does not match num_classes in model {inp0shape[-1]}')
186
+ # raise SystemExit(1)
187
+ # TBD add custom parameters in OpenVino model file? Then add feature, tsteps, truth mutex
188
+
189
+ # note simple compile and prediction run:
190
+ # compiled_model = ov.compile_model(ov_model, "AUTO") # compile model
191
+ import numpy as np
+ # NOTE(review): debug smoke test -- one asynchronous inference on a hard-coded ones([1, 1, 402])
+ # float32 tensor; 402 is presumably the feature width of one particular model. The results
+ # (input_tensor, output_tensor, output_buffer) are never used afterwards.
192
+ inp_sample_np = np.ones([1,1,402], dtype=np.single)
193
+ #inp_sample_t = torch.rand(1, 1, 402)
194
+ shared_in = ov.Tensor(array=inp_sample_np, shared_memory=True) # Create tensor, external memory, from numpy
195
+ infer_request = compiled_model.create_infer_request()
196
+ infer_request.set_input_tensor(shared_in) # Set input tensor for model with one input
197
+ input_tensor = infer_request.get_input_tensor() # for debug
198
+ output_tensor = infer_request.get_output_tensor() # for debug
199
+
200
+ infer_request.start_async()
201
+ infer_request.wait()
202
+ output = infer_request.get_output_tensor() # Get output tensor for model with one output
203
+ output_buffer = output.data # output_buffer[] - accessing output tensor data
204
+
205
+
206
+
207
+
208
+ from sonusai.utils import reshape_inputs
209
+ from sonusai.utils import reshape_outputs
210
+ from sonusai.mixture import get_audio_from_feature
211
+ from sonusai.utils import write_wav
212
+ import h5py
213
+ for mixid in p_mixids:
214
+ feature, _ = mixdb.mixture_ft(mixid) # frames x stride x feature_params
215
+ feature, _ = reshape_inputs(feature=feature,
216
+ batch_size=1,
217
+ timesteps=feature.shape[0],
218
+ flatten=flattened,
219
+ add1ch=add1ch)
220
+
221
+ predict = compiled_model(feature) # run inference, model wants i.e. batch x tsteps x feat_params
+ # NOTE(review): the raw result of compiled_model(...) is passed to reshape_outputs without
+ # conversion (see TBD below) -- confirm which result type reshape_outputs expects.
222
+ #TBD convert to numpy
223
+ predict, _ = reshape_outputs(predict=predict, timesteps=tsteps)
224
+ output_name = join(output_dir, mixdb.mixtures[mixid].name)
225
+ with h5py.File(output_name, 'a') as f:
226
+ if 'predict' in f:
227
+ del f['predict']
228
+ f.create_dataset('predict', data=predict)
229
+ if writewav:
230
+ predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
231
+ write_wav()
+ # NOTE(review): write_wav() is called with no arguments and predict_audio is never used --
+ # the --write-wav path is unfinished.
232
+
233
+
234
+
235
+ # sampler = None
236
+ # p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
237
+ # mixids=p_mixids,
238
+ # batch_size=batch_size,
239
+ # cut_len=0,
240
+ # flatten=model.flatten,
241
+ # add1ch=model.add1ch,
242
+ # random_cut=False,
243
+ # sampler=sampler,
244
+ # drop_last=False,
245
+ # num_workers=dlcpu)
246
+
247
+ # Info needed to set up inverse transform
+ # NOTE(review): from here to the end of main, the code is carried over from the torch- and
+ # ONNX-based predict scripts. It references names never assigned in this function (model,
+ # FeatureGenerator, TorchInverseTransform, wavdbg, p_datagen, torch, pad_and_predict). It also
+ # reads timesteps before its later assignment, so it cannot run as written. Note too that
+ # FeatureGenerator receives the feature array, not a feature-mode string.
248
+ half = model.num_classes // 2
249
+ fg = FeatureGenerator(feature_mode=feature,
250
+ num_classes=model.num_classes,
251
+ truth_mutex=model.truth_mutex)
252
+ itf = TorchInverseTransform(N=fg.itransform_N,
253
+ R=fg.itransform_R,
254
+ bin_start=fg.bin_start,
255
+ bin_end=fg.bin_end,
256
+ ttype=fg.itransform_ttype)
257
+
258
+ enable_truth_wav = False
259
+ enable_mix_wav = False
260
+ if wavdbg:
261
+ if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
262
+ enable_mix_wav = True
263
+ enable_truth_wav = True
264
+ elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
265
+ enable_truth_wav = True
266
+
267
+ if reset:
268
+ logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
269
+ for idx, val in enumerate(p_datagen):
270
+ # truth = val[1]
271
+ feature = val[0]
272
+ with torch.no_grad():
273
+ ypred = model(feature)
274
+ output_name = join(output_dir, mixdb.mixtures[idx].name)
275
+ pdat = ypred.detach().numpy()
276
+ if timesteps > 0:
277
+ logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
278
+ logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
279
+ with h5py.File(output_name, 'a') as f:
280
+ if 'predict' in f:
281
+ del f['predict']
282
+ f.create_dataset('predict', data=pdat)
283
+
284
+ if wavdbg:
285
+ owav_base = splitext(output_name)[0]
286
+ tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
287
+ itf.reset()
288
+ predwav, _ = itf.execute_all(tmp)
289
+ # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
290
+ write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
291
+ if enable_truth_wav:
292
+ # Note: this supports truth types target_f and target_mixture_f
293
+ tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
294
+ itf.reset()
295
+ truthwav, _ = itf.execute_all(tmp)
296
+ write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
297
+
298
+ if enable_mix_wav:
299
+ tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
300
+ itf.reset()
301
+ mixwav, _ = itf.execute_all(tmp.detach())
302
+ write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
303
+
304
+
305
+
306
+
307
+
308
+
309
+
310
+
311
+
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+ from os import makedirs
+ # NOTE(review): leftover from the ONNX predict script. It re-imports names already imported above
+ # and creates a second output directory, named 'ovpredict' without the input basename.
321
+ from os.path import isdir
322
+ from os.path import join
323
+ from os.path import splitext
324
+
325
+ import h5py
326
+ import onnxruntime as rt
327
+ import numpy as np
328
+
329
+ from sonusai.mixture import Feature
330
+ from sonusai.mixture import Predict
331
+ from sonusai.utils import SonusAIMetaData
332
+ from sonusai import create_file_handler
333
+ from sonusai import initial_log_messages
334
+ from sonusai import update_console_handler
335
+ from sonusai.mixture import MixtureDatabase
336
+ from sonusai.mixture import get_feature_from_audio
337
+ from sonusai.mixture import read_audio
338
+ from sonusai.utils import create_ts_name
339
+ from sonusai.utils import get_frames_per_batch
340
+ from sonusai.utils import get_sonusai_metadata
341
+
342
+ output_dir = create_ts_name('ovpredict')
343
+ makedirs(output_dir, exist_ok=True)
344
+
345
+
346
+
347
+ model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+ # NOTE(review): MODEL is loaded with onnxruntime here, although it is documented (and compiled
+ # above) as an OpenVINO .xml model.
348
+ model_metadata = get_sonusai_metadata(model)
349
+
350
+ batch_size = model_metadata.input_shape[0]
351
+ if model_metadata.timestep:
352
+ timesteps = model_metadata.input_shape[1]
353
+ else:
354
+ timesteps = 0
355
+ num_classes = model_metadata.output_shape[-1]
356
+
357
+ frames_per_batch = get_frames_per_batch(batch_size, timesteps)
358
+
359
+ logger.info('')
360
+ logger.info(f'feature {model_metadata.feature}')
361
+ logger.info(f'num_classes {num_classes}')
362
+ logger.info(f'batch_size {batch_size}')
363
+ logger.info(f'timesteps {timesteps}')
364
+ logger.info(f'flatten {model_metadata.flattened}')
365
+ logger.info(f'add1ch {model_metadata.channel}')
366
+ logger.info(f'truth_mutex {model_metadata.mutex}')
367
+ logger.info(f'input_shape {model_metadata.input_shape}')
368
+ logger.info(f'output_shape {model_metadata.output_shape}')
369
+ logger.info('')
370
+
371
+ if splitext(entries)[1] == '.wav':
+ # NOTE(review): entries is a list of PathInfo (audio input) or None (mixture-database input), not
+ # a path, so splitext(entries) raises TypeError. read_audio(), splitext() and isdir() below are
+ # also called without arguments.
372
+ # Convert WAV to feature data
373
+ logger.info('')
374
+ logger.info(f'Run prediction on {entries}')
375
+ audio = read_audio()
376
+ feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
377
+
378
+ predict = pad_and_predict(feature=feature,
379
+ model_name=model_name,
380
+ model_metadata=model_metadata,
381
+ frames_per_batch=frames_per_batch,
382
+ batch_size=batch_size,
383
+ timesteps=timesteps,
384
+ reset=reset)
385
+
386
+ output_name = splitext()[0] + '.h5'
387
+ with h5py.File(output_name, 'a') as f:
388
+ if 'feature' in f:
389
+ del f['feature']
390
+ f.create_dataset(name='feature', data=feature)
391
+
392
+ if 'predict' in f:
393
+ del f['predict']
394
+ f.create_dataset(name='predict', data=predict)
395
+
396
+ logger.info(f'Saved results to {output_name}')
397
+ return
398
+
399
+ if not isdir():
400
+ logger.exception(f'Do not know how to process input from {entries}')
401
+ raise SystemExit(1)
402
+
403
+ mixdb = MixtureDatabase()
+ # NOTE(review): MixtureDatabase() is built without a location here (compare MixtureDatabase(mixdb_path) above).
404
+
405
+ if mixdb.feature != model_metadata.feature:
406
+ logger.exception(f'Feature in mixture database does not match feature in model')
407
+ raise SystemExit(1)
408
+
409
+ mixids = mixdb.mixids_to_list(mixids)
410
+ if reset:
411
+ # reset mode cycles through each file one at a time
412
+ for mixid in mixids:
413
+ feature, _ = mixdb.mixture_ft(mixid)
414
+
415
+ predict = pad_and_predict(feature=feature,
416
+ model_name=model_name,
417
+ model_metadata=model_metadata,
418
+ frames_per_batch=frames_per_batch,
419
+ batch_size=batch_size,
420
+ timesteps=timesteps,
421
+ reset=reset)
422
+
423
+ output_name = join(output_dir, mixdb.mixtures[mixid].name)
424
+ with h5py.File(output_name, 'a') as f:
425
+ if 'predict' in f:
426
+ del f['predict']
427
+ f.create_dataset(name='predict', data=predict)
428
+ else:
429
+ features: list[Feature] = []
430
+ file_indices: list[slice] = []
431
+ total_frames = 0
432
+ for mixid in mixids:
433
+ current_feature, _ = mixdb.mixture_ft(mixid)
434
+ current_frames = current_feature.shape[0]
435
+ features.append(current_feature)
436
+ file_indices.append(slice(total_frames, total_frames + current_frames))
437
+ total_frames += current_frames
438
+ feature = np.vstack([features[i] for i in range(len(features))])
439
+
440
+ predict = pad_and_predict(feature=feature,
441
+ model_name=model_name,
442
+ model_metadata=model_metadata,
443
+ frames_per_batch=frames_per_batch,
444
+ batch_size=batch_size,
445
+ timesteps=timesteps,
446
+ reset=reset)
447
+
448
+ # Write data to separate files
449
+ for idx, mixid in enumerate(mixids):
450
+ output_name = join(output_dir, mixdb.mixtures[mixid].name)
451
+ with h5py.File(output_name, 'a') as f:
452
+ if 'predict' in f:
453
+ del f['predict']
454
+ f.create_dataset('predict', data=predict[file_indices[idx]])
455
+
456
+ logger.info(f'Saved results to {output_dir}')
457
+
458
+
459
+ # def pad_and_predict(feature: Feature,
460
+ # model_name: str,
461
+ # model_metadata: SonusAIMetaData,
462
+ # frames_per_batch: int,
463
+ # batch_size: int,
464
+ # timesteps: int,
465
+ # reset: bool) -> Predict:
466
+ # import onnxruntime as rt
467
+ # import numpy as np
468
+ #
469
+ # from sonusai.utils import reshape_inputs
470
+ # from sonusai.utils import reshape_outputs
471
+ #
472
+ # frames = feature.shape[0]
473
+ # padding = frames_per_batch - frames % frames_per_batch
474
+ # feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
475
+ # feature, _ = reshape_inputs(feature=feature,
476
+ # batch_size=batch_size,
477
+ # timesteps=timesteps,
478
+ # flatten=model_metadata.flattened,
479
+ # add1ch=model_metadata.channel)
480
+ # sequences = feature.shape[0] // model_metadata.input_shape[0]
481
+ # feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
482
+ #
483
+ # model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
484
+ # output_names = [n.name for n in model.get_outputs()]
485
+ # input_names = [n.name for n in model.get_inputs()]
486
+ #
487
+ # predict = []
488
+ # for sequence in range(sequences):
489
+ # predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
490
+ # if reset:
491
+ # model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
492
+ #
493
+ # predict_arr = np.vstack(predict)
494
+ # # Combine [sequences, batch_size, ...] into [frames, ...]
495
+ # predict_shape = predict_arr.shape
496
+ # predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
497
+ # predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
498
+ # predict_arr = predict_arr[:frames, :]
499
+ #
500
+ # return predict_arr
501
+
502
+
503
if __name__ == '__main__':
    # A Ctrl-C during main() is reported as a cancellation and the process
    # exits with status 0 instead of printing a KeyboardInterrupt traceback.
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        raise SystemExit(0)
@@ -0,0 +1,47 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (C) 2018-2024 Intel Corporation
4
+ # SPDX-License-Identifier: Apache-2.0
5
+ import logging as log
6
+ import sys
7
+
8
+ import openvino as ov
9
+
10
+
11
def param_to_string(parameters) -> str:
    """Format a value returned by OpenVINO's ``get_property`` for logging.

    Lists and tuples become their items' ``str()`` values separated by ``', '``;
    every other value is converted with ``str()``.
    """
    is_sequence = isinstance(parameters, (list, tuple))
    return ', '.join(str(item) for item in parameters) if is_sequence else str(parameters)
17
+
18
+
19
def main():
    """Log every available OpenVINO device and the values of its supported properties.

    Returns 0 so the result can be passed directly to ``sys.exit``.
    """
    log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)

    # --------------------------- Step 1. Initialize OpenVINO Runtime Core --------------------------------------------
    core = ov.Core()

    # --------------------------- Step 2. Get metrics of available devices --------------------------------------------
    log.info('Available devices:')
    for device in core.available_devices:
        log.info(f'{device} :')
        log.info('\tSUPPORTED_PROPERTIES:')
        for property_key in core.get_property(device, 'SUPPORTED_PROPERTIES'):
            # Skip the meta-property that lists the others. This must be an
            # equality test: ('SUPPORTED_PROPERTIES') is a plain string, not a
            # tuple, so the former `not in` check was a substring match.
            if property_key == 'SUPPORTED_PROPERTIES':
                continue
            try:
                property_val = core.get_property(device, property_key)
            except TypeError:
                # Report values get_property cannot return instead of aborting the listing.
                property_val = 'UNSUPPORTED TYPE'
            log.info(f'\t\t{property_key}: {param_to_string(property_val)}')
        log.info('')

    # -----------------------------------------------------------------------------------------------------------------
    return 0
41
+
42
+
43
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    status = main()
    sys.exit(status)
45
+
46
+
47
+
sonusai/plot.py CHANGED
@@ -41,16 +41,29 @@ Outputs:
41
41
 
42
42
  """
43
43
 
44
+ import signal
45
+
44
46
  import numpy as np
45
47
  from matplotlib import pyplot as plt
46
48
 
47
- from sonusai import logger
48
49
  from sonusai.mixture import AudioT
49
50
  from sonusai.mixture import Feature
50
51
  from sonusai.mixture import Predict
51
52
  from sonusai.mixture import Truth
52
53
 
53
54
 
55
def signal_handler(_sig, _frame):
    """SIGINT handler: log that the run was canceled, then exit with status 1.

    Installed at import time by the ``signal.signal`` call that follows.
    """
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    raise SystemExit(1)


# Install the handler as soon as this module is imported.
signal.signal(signal.SIGINT, signal_handler)
65
+
66
+
54
67
  def spec_plot(mixture: AudioT,
55
68
  feature: Feature,
56
69
  predict: Predict = None,
@@ -264,6 +277,7 @@ def main() -> None:
264
277
  from sonusai import SonusAIError
265
278
  from sonusai import create_file_handler
266
279
  from sonusai import initial_log_messages
280
+ from sonusai import logger
267
281
  from sonusai import update_console_handler
268
282
  from sonusai.mixture import MixtureDatabase
269
283
  from sonusai.mixture import FeatureGeneratorConfig
@@ -457,8 +471,4 @@ def main() -> None:
457
471
 
458
472
 
459
473
  if __name__ == '__main__':
460
- try:
461
- main()
462
- except KeyboardInterrupt:
463
- logger.info('Canceled due to keyboard interrupt')
464
- raise SystemExit(0)
474
+ main()
+ # Ctrl-C is no longer caught here: the module-level SIGINT handler (signal_handler) logs the cancellation and exits with status 1.
@@ -20,9 +20,20 @@ Outputs the following to post_spenh_targetf-<TIMESTAMP> directory:
20
20
  post_spenh_targetf.log
21
21
 
22
22
  """
23
+ import signal
23
24
  from dataclasses import dataclass
24
25
 
25
- from sonusai import logger
26
+
27
def signal_handler(_sig, _frame):
    """Log a keyboard-interrupt cancellation and terminate the process with status 1."""
    import sys as _sys

    from sonusai import logger as sonusai_logger

    sonusai_logger.info('Canceled due to keyboard interrupt')
    _sys.exit(1)


# Register signal_handler for SIGINT (Ctrl-C) when this module is imported.
signal.signal(signal.SIGINT, signal_handler)
26
37
 
27
38
 
28
39
  @dataclass
@@ -146,8 +157,4 @@ def _process(file: str) -> None:
146
157
 
147
158
 
148
159
  if __name__ == '__main__':
149
- try:
150
- main()
151
- except KeyboardInterrupt:
152
- logger.info('Canceled due to keyboard interrupt')
153
- exit()
160
+ main()
+ # Ctrl-C is no longer caught here: the module-level SIGINT handler (signal_handler) logs the cancellation and exits with status 1.