sonusai 0.15.8__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. sonusai/__init__.py +35 -4
  2. sonusai/audiofe.py +237 -0
  3. sonusai/calc_metric_spenh.py +21 -12
  4. sonusai/genft.py +2 -1
  5. sonusai/genmixdb.py +5 -5
  6. sonusai/lsdb.py +2 -2
  7. sonusai/main.py +58 -61
  8. sonusai/mixture/__init__.py +4 -2
  9. sonusai/mixture/audio.py +0 -34
  10. sonusai/mixture/config.py +1 -2
  11. sonusai/mixture/datatypes.py +1 -1
  12. sonusai/mixture/feature.py +75 -21
  13. sonusai/mixture/helpers.py +60 -30
  14. sonusai/mixture/log_duration_and_sizes.py +2 -2
  15. sonusai/mixture/mixdb.py +13 -10
  16. sonusai/mixture/spectral_mask.py +14 -14
  17. sonusai/mixture/truth_functions/data.py +1 -1
  18. sonusai/mixture/truth_functions/target.py +2 -2
  19. sonusai/mkmanifest.py +29 -2
  20. sonusai/onnx_predict.py +1 -1
  21. sonusai/plot.py +4 -4
  22. sonusai/post_spenh_targetf.py +8 -8
  23. sonusai/utils/__init__.py +8 -7
  24. sonusai/utils/asl_p56.py +3 -3
  25. sonusai/utils/asr.py +35 -8
  26. sonusai/utils/asr_functions/__init__.py +0 -5
  27. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  28. sonusai/utils/asr_manifest_functions/__init__.py +1 -0
  29. sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
  30. sonusai/utils/audio_devices.py +41 -0
  31. sonusai/utils/calculate_input_shape.py +3 -4
  32. sonusai/utils/create_timestamp.py +5 -0
  33. sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
  34. sonusai/utils/model_utils.py +30 -0
  35. sonusai/utils/onnx_utils.py +19 -45
  36. sonusai/utils/reshape.py +11 -11
  37. sonusai/utils/wave.py +12 -5
  38. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/METADATA +8 -19
  39. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/RECORD +41 -54
  40. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/WHEEL +1 -1
  41. sonusai/data_generator/__init__.py +0 -5
  42. sonusai/data_generator/dataset_from_mixdb.py +0 -143
  43. sonusai/data_generator/keras_from_mixdb.py +0 -169
  44. sonusai/data_generator/torch_from_mixdb.py +0 -122
  45. sonusai/evaluate.py +0 -245
  46. sonusai/keras_onnx.py +0 -86
  47. sonusai/keras_predict.py +0 -231
  48. sonusai/keras_train.py +0 -334
  49. sonusai/torchl_onnx.py +0 -216
  50. sonusai/torchl_predict.py +0 -547
  51. sonusai/torchl_train.py +0 -223
  52. sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
  53. sonusai/utils/asr_functions/data.py +0 -16
  54. sonusai/utils/asr_functions/deepgram.py +0 -97
  55. sonusai/utils/asr_functions/fastwhisper.py +0 -90
  56. sonusai/utils/asr_functions/google.py +0 -95
  57. sonusai/utils/asr_functions/whisper.py +0 -49
  58. sonusai/utils/keras_utils.py +0 -226
  59. {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/entry_points.txt +0 -0
sonusai/evaluate.py DELETED
@@ -1,245 +0,0 @@
"""sonusai evaluate

usage: evaluate [-hv] [-i MIXID] (-f FEATURE) (-p PREDICT) [-t PTHR] LOC

options:
    -h, --help
    -v, --verbose                   Be verbose.
    -i MIXID, --mixid MIXID         Mixture ID(s) to generate. [default: *].
    -f FEATURE, --feature FEATURE   An HDF5 file containing truth and segsnr data.
    -p PREDICT, --predict PREDICT   A directory containing prediction data.
    -t PTHR, --thr PTHR             Optional prediction decision threshold(s). [default: 0].

Evaluate calculates performance metrics of neural-network models from model prediction data and genft data.

Inputs:
    LOC         A SonusAI mixture database directory.
    MIXID       A glob of mixture ID(s) to generate.
    FEATURE     An HDF5 file of genft data. Contains:
                    dataset:    truth_f
                    dataset:    segsnr
    PREDICT     A directory containing SonusAI predict HDF5 files. Contains:
                    dataset:    predict (either [frames, num_classes] or [frames, timesteps, num_classes])
    PTHR        Scalar or array of thresholds. Default 0 will select values:
                    argmax()    if mixdb indicates single-label mode (truth_mutex = true)
                    0.5         if mixdb indicates multi-label mode (truth_mutex = false)
                If PTHR = -1, optimal thresholds are calculated using precision_recall_curve() which
                optimizes F1 score.
"""
25
- import numpy as np
26
-
27
- from sonusai import logger
28
- from sonusai.mixture import Feature
29
- from sonusai.mixture import MixtureDatabase
30
- from sonusai.mixture import Predict
31
- from sonusai.mixture import Segsnr
32
- from sonusai.mixture import Truth
33
-
34
-
35
def evaluate(mixdb: MixtureDatabase,
             truth: Truth,
             predict: Predict = None,
             segsnr: Segsnr = None,
             output_dir: str = None,
             predict_thr: float | np.ndarray = 0,
             feature: Feature = None,
             verbose: bool = False) -> None:
    """Log classification metrics of predict against truth, per SNR and per class.

    :param mixdb: Mixture database the truth/predict data belongs to
    :param truth: Truth data, [frames, num_classes] or [frames, timesteps, num_classes]
    :param predict: Prediction data in the same layout as truth (required in practice;
                    padded with zeros or trimmed so its frame count matches truth)
    :param segsnr: Optional segmental SNR data in transform frames; ignored when it is
                   None, empty, or its length does not match truth
    :param output_dir: If given, metric tables are also written there as CSV files
                       (snr.csv, snrwo.csv in truth_mutex mode, class_snr<SNR>.csv)
    :param predict_thr: Decision threshold(s). 0 selects the default (argmax in
                        truth_mutex mode, 0.5 in multi-label mode); -1 requests
                        optimal thresholds from calc_optimal_thresholds()
    :param feature: Optional feature data; only its size is reported
    :param verbose: Passed to update_console_handler()
    :raises SystemExit: If predict is missing or its class count differs from truth
    """
    from os.path import join

    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.metrics import calc_optimal_thresholds
    from sonusai.metrics import class_summary
    from sonusai.metrics import snr_summary
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.queries import get_mixids_from_snr
    from sonusai.utils import get_num_classes_from_predict
    from sonusai.utils import human_readable_size
    from sonusai.utils import reshape_outputs
    from sonusai.utils import seconds_to_hms

    update_console_handler(verbose)
    initial_log_messages('evaluate')

    # predict defaults to None in the signature but is dereferenced immediately; fail cleanly.
    if predict is None:
        logger.error('No prediction data provided. Exiting.')
        raise SystemExit(1)

    # logger.error (not logger.exception): there is no active exception to attach here.
    if truth.shape[-1] != predict.shape[-1]:
        logger.error('Number of classes in truth and predict are not equal. Exiting.')
        raise SystemExit(1)

    # truth, predict can be either [frames, num_classes] or [frames, timesteps, num_classes]
    # in binary case dim may not exist, detect this and set num_classes == 1
    timesteps = -1
    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)

    # Make predict the same number of frames as truth
    fdiff = truth.shape[0] - predict.shape[0]
    if fdiff > 0:
        predict = np.concatenate((predict, np.zeros((fdiff, num_classes), dtype=np.float32)))
        logger.info('Truth has more feature-frames than predict, padding predict with zeros to match.')
    elif fdiff < 0:
        predict = predict[0:fdiff, :]
        logger.info('Predict has more feature-frames than truth, trimming predict to match.')

    # Check segsnr, input is always in transform frames
    compute_segsnr = False
    if segsnr is not None and len(segsnr) > 0:
        segsnr_feature_frames = segsnr.shape[0] / (mixdb.feature_step_samples / mixdb.ft_config.R)
        if segsnr_feature_frames == truth.shape[0]:
            compute_segsnr = True
        else:
            logger.warning('segsnr length does not match truth, ignoring.')

    # Check predict_thr array or scalar and return final scalar predict_thr value
    if mixdb.truth_mutex:
        # single-label mode, force argmax mode
        predict_thr = np.atleast_1d(0)
    elif num_classes > 1:
        if not isinstance(predict_thr, np.ndarray) or predict_thr.ndim == 0:
            # Scalar (including a 0-d array, which is what main() passes):
            # multi-label predict_thr 0 force to 0.5 default
            predict_thr = np.atleast_1d(0.5 if predict_thr == 0 else predict_thr)
        elif predict_thr.ndim == 1:
            # multi-label predict_thr array: 0 force to 0.5 default, otherwise use array[0]
            predict_thr = np.atleast_1d(0.5) if predict_thr[0] == 0 else predict_thr[0]

    if predict_thr == -1:
        thrpr, _, _, _ = calc_optimal_thresholds(truth, predict, timesteps)
        predict_thr = np.atleast_1d(thrpr)
        predict_thr = np.maximum(predict_thr, 0.001)  # enforce lower limit
        predict_thr = np.minimum(predict_thr, 0.999)  # enforce upper limit
        predict_thr = predict_thr.round(2)

    # Summarize the mixture data
    num_mixtures = mixdb.num_mixtures
    total_samples = sum(mixture.samples for mixture in mixdb.mixtures)
    duration = total_samples / SAMPLE_RATE

    logger.info('')
    logger.info(f'Mixtures: {num_mixtures}')
    logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
    logger.info(f'truth: {human_readable_size(truth.nbytes, 1)}')
    logger.info(f'predict: {human_readable_size(predict.nbytes, 1)}')
    if compute_segsnr:
        logger.info(f'segsnr: {human_readable_size(segsnr.nbytes, 1)}')
    # "is not None": truth-testing a multi-element ndarray raises ValueError.
    if feature is not None:
        logger.info(f'feature: {human_readable_size(feature.nbytes, 1)}')

    logger.info(f'Classes: {num_classes}')
    if mixdb.truth_mutex:
        logger.info('Mode: Single-label / truth_mutex / softmax')
    else:
        logger.info('Mode: Multi-label / Binary')

    mxid_snro = get_mixids_from_snr(mixdb=mixdb)
    snrlist = sorted(mxid_snro.keys(), reverse=True)
    logger.info(f'Ordered SNRs: {snrlist}\n')
    predict_thr_info = predict_thr.transpose() if isinstance(predict_thr, np.ndarray) else predict_thr
    logger.info(f'Prediction Threshold(s): {predict_thr_info}\n')

    # Top-level report over all mixtures
    _, microdf, _, mxid_snro = snr_summary(mixdb=mixdb,
                                           mixid=':',
                                           truth_f=truth,
                                           predict=predict,
                                           segsnr=segsnr if compute_segsnr else None,
                                           predict_thr=predict_thr)

    if num_classes > 1:
        logger.info(f'Metrics micro-avg per SNR over all {num_mixtures} mixtures:')
    else:
        logger.info(f'Metrics per SNR over all {num_mixtures} mixtures:')
    logger.info(microdf.round(3).to_string())
    logger.info('')
    if output_dir:
        microdf.round(3).to_csv(join(output_dir, 'snr.csv'))

    if mixdb.truth_mutex:
        # Repeat the report with the last class column dropped from truth and predict
        _, microdf, _, mxid_snro = snr_summary(mixdb=mixdb,
                                               mixid=':',
                                               truth_f=truth[:, 0:-1],
                                               predict=predict[:, 0:-1],
                                               segsnr=segsnr if compute_segsnr else None,
                                               predict_thr=predict_thr)

        logger.info(f'Metrics micro-avg without "Other" class per SNR over all {num_mixtures} mixtures:')
        logger.info(microdf.round(3).to_string())
        logger.info('')
        if output_dir:
            microdf.round(3).to_csv(join(output_dir, 'snrwo.csv'))

    # Per-class report for each SNR, highest SNR first
    for snri in snrlist:
        mxids = mxid_snro[snri]
        classdf = class_summary(mixdb, mxids, truth, predict, predict_thr)
        logger.info(f'Metrics per class for SNR {snri} over {len(mxids)} mixtures:')
        logger.info(classdf.round(3).to_string())
        logger.info('')
        if output_dir:
            classdf.round(3).to_csv(join(output_dir, f'class_snr{snri}.csv'))
184
-
185
-
186
def main() -> None:
    """Command-line entry point: load truth, segsnr and prediction data, then run evaluate().

    Results are logged into an ``evaluate-<date>`` directory; if that directory
    cannot be created, a name with an added time suffix is tried instead.

    :raises SonusAIError: If neither output directory can be created
    """
    from datetime import datetime
    from os import mkdir
    from os.path import join

    import h5py
    from docopt import docopt

    import sonusai
    from sonusai import SonusAIError
    from sonusai import create_file_handler
    from sonusai.utils import read_predict_data
    from sonusai.utils import trim_docstring

    options = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = options['--verbose']
    feature_name = options['--feature']
    predict_name = options['--predict']
    predict_threshold = np.array(float(options['--thr']), dtype=np.float32)
    location = options['LOC']

    mixdb = MixtureDatabase(location)

    # Create the output directory: date-only name first, date+time name as the fallback
    output_dir = f'evaluate-{datetime.now():%Y%m%d}'
    try:
        mkdir(output_dir)
    except OSError:
        output_dir = f'evaluate-{datetime.now():%Y%m%d-%H%M%S}'
        try:
            mkdir(output_dir)
        except OSError as error:
            raise SonusAIError(f'Could not create directory, {output_dir}: {error}')

    create_file_handler(join(output_dir, 'evaluate.log'))

    with h5py.File(feature_name, 'r') as h5_file:
        truth_f = np.array(h5_file['truth_f'])
        segsnr = np.array(h5_file['segsnr'])

    evaluate(mixdb=mixdb,
             truth=truth_f,
             predict=read_predict_data(predict_name),
             segsnr=segsnr,
             output_dir=output_dir,
             predict_thr=predict_threshold,
             verbose=verbose)

    logger.info(f'Wrote results to {output_dir}')
238
-
239
-
240
# Script entry point: run main() and treat Ctrl-C as a clean (exit status 0) cancellation.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        raise SystemExit(0)
sonusai/keras_onnx.py DELETED
@@ -1,86 +0,0 @@
1
- """sonusai keras_onnx
2
-
3
- usage: keras_onnx [-hvr] (-m MODEL) (-w WEIGHTS) [-b BATCH] [-t TSTEPS] [-o OUTPUT]
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -m MODEL, --model MODEL Python model file.
9
- -w WEIGHTS, --weights WEIGHTS Keras model weights file.
10
- -b BATCH, --batch BATCH Batch size.
11
- -t TSTEPS, --tsteps TSTEPS Timesteps.
12
- -o OUTPUT, --output OUTPUT Output directory.
13
-
14
- Convert a trained Keras model to ONNX.
15
-
16
- Inputs:
17
- MODEL A SonusAI Python model file with build and/or hypermodel functions.
18
- WEIGHTS A Keras model weights file (or model file with weights).
19
-
20
- Outputs:
21
- OUTPUT/ A directory containing:
22
- <MODEL>.onnx Model file with batch_size and timesteps equal to provided parameters
23
- <MODEL>-b1.onnx Model file with batch_size=1 and if the timesteps dimension exists it
24
- is set to 1 (useful for real-time inference applications)
25
- keras_onnx.log
26
-
27
- Results are written into subdirectory <MODEL>-<TIMESTAMP> unless OUTPUT is specified.
28
-
29
- """
30
- from sonusai import logger
31
-
32
-
33
def main() -> None:
    """Command-line entry point: convert a trained Keras model to ONNX.

    Parses the command line, prepares the output directory and its log file,
    then hands the conversion to sonusai.utils.keras_onnx().
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    options = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = options['--verbose']
    model_name = options['--model']
    weight_name = options['--weights']
    batch_size = options['--batch']
    timesteps = options['--tsteps']
    output_dir = options['--output']

    # Remaining imports are deferred until after argument parsing
    from os import makedirs
    from os.path import basename
    from os.path import join
    from os.path import splitext

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.utils import create_ts_name
    from sonusai.utils import keras_onnx

    # Model file name without directory or extension
    model_root = splitext(basename(model_name))[0]

    # Optional numeric arguments arrive as strings
    batch_size = None if batch_size is None else int(batch_size)
    timesteps = None if timesteps is None else int(timesteps)

    # Default output location is derived from the model name via create_ts_name()
    if output_dir is None:
        output_dir = create_ts_name(model_root)

    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_onnx.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_onnx')

    keras_onnx(model_name, weight_name, timesteps, batch_size, output_dir)
79
-
80
-
81
# Script entry point: run main() and treat Ctrl-C as a clean cancellation.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is an interactive helper installed by the site module and may be
        # absent (e.g. python -S); SystemExit(0) gives the same zero exit status.
        raise SystemExit(0)
sonusai/keras_predict.py DELETED
@@ -1,231 +0,0 @@
1
- """sonusai keras_predict
2
-
3
- usage: keras_predict [-hvr] [-i MIXID] (-m MODEL) (-w KMODEL) [-b BATCH] [-t TSTEPS] INPUT ...
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
9
- -m MODEL, --model MODEL Python model file.
10
- -w KMODEL, --weights KMODEL Keras model weights file.
11
- -b BATCH, --batch BATCH Batch size.
12
- -t TSTEPS, --tsteps TSTEPS Timesteps.
13
- -r, --reset Reset model between each file.
14
-
15
- Run prediction on a trained Keras model defined by a SonusAI Keras Python model file using SonusAI genft or WAV data.
16
-
17
- Inputs:
18
- MODEL A SonusAI Python model file with build and/or hypermodel functions.
19
- KMODEL A Keras model weights file (or model file with weights).
20
- INPUT The input data must be one of the following:
21
- * Single WAV file or glob of WAV files
22
- Using the given model, generate feature data and run prediction. A model file must be
23
- provided. The MIXID is ignored.
24
-
25
- * directory
26
- Using the given SonusAI mixture database directory, generate feature and truth data if not found.
27
- Run prediction. The MIXID is required.
28
-
29
- Outputs the following to kpredict-<TIMESTAMP> directory:
30
- <id>.h5
31
- dataset: predict
32
- keras_predict.log
33
-
34
- """
35
- from typing import Any
36
-
37
- from sonusai import logger
38
- from sonusai.mixture import Feature
39
- from sonusai.mixture import Predict
40
-
41
-
42
def main() -> None:
    """Command-line entry point: run prediction with a trained Keras model.

    INPUT is either a single mixture database directory or one or more WAV
    files. Predictions are written as HDF5 files, together with a log file,
    into a directory named by create_ts_name('kpredict').

    :raises SystemExit: If the mixture database feature does not match the
                        model feature, or the input is neither a single
                        directory nor a set of existing .wav files
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    model_name = args['--model']
    weights_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    reset = args['--reset']
    input_name = args['INPUT']

    from os import makedirs
    from os.path import basename
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import splitext

    import keras_tuner as kt
    import tensorflow as tf
    from keras import backend as kb

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import import_and_check_keras_model
    from sonusai.utils import reshape_outputs

    # Optional numeric arguments arrive as strings
    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name('kpredict')
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_predict')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    hypermodel = import_and_check_keras_model(model_name=model_name,
                                              weights_name=weights_name,
                                              timesteps=timesteps,
                                              batch_size=batch_size)
    built_model = hypermodel.build_model(kt.HyperParameters())

    frames_per_batch = get_frames_per_batch(hypermodel.batch_size, hypermodel.timesteps)

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info('')

    logger.info(f'Loading weights from {weights_name}')
    built_model.load_weights(weights_name)

    logger.info('')
    if len(input_name) == 1 and isdir(input_name[0]):
        # Mixture database input
        input_name = input_name[0]
        logger.info(f'Load mixture database from {input_name}')
        mixdb = MixtureDatabase(input_name)

        if mixdb.feature != hypermodel.feature:
            # logger.error (not logger.exception): there is no active exception to attach here.
            logger.error('Feature in mixture database does not match feature in model')
            raise SystemExit(1)

        mixids = mixdb.mixids_to_list(mixids)
        if reset:
            # reset mode cycles through each file one at a time
            for mixid in mixids:
                feature, _ = mixdb.mixture_ft(mixid)

                feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                                    built_model=built_model,
                                                    feature=feature,
                                                    frames_per_batch=frames_per_batch)

                _write_datasets(join(output_dir, mixdb.mixtures[mixid].name), {'predict': predict})
        else:
            # Run all data at once using a data generator
            feature = KerasFromH5(mixdb=mixdb,
                                  mixids=mixids,
                                  batch_size=hypermodel.batch_size,
                                  timesteps=hypermodel.timesteps,
                                  flatten=hypermodel.flatten,
                                  add1ch=hypermodel.add1ch)

            predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
            predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)

            # Write data to separate files
            for idx, mixid in enumerate(mixids):
                _write_datasets(join(output_dir, mixdb.mixtures[mixid].name),
                                {'predict': predict[feature.file_indices[idx]]})

        logger.info(f'Saved results to {output_dir}')
        return

    # WAV file input
    if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
        logger.error(f'Do not know how to process input from {input_name}')
        raise SystemExit(1)

    logger.info(f'Run prediction on {len(input_name):,} WAV files')
    for file in input_name:
        # Convert WAV to feature data
        audio = read_audio(file)
        feature = get_feature_from_audio(audio=audio, feature=hypermodel.feature)

        feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                            built_model=built_model,
                                            feature=feature,
                                            frames_per_batch=frames_per_batch)

        _write_datasets(join(output_dir, splitext(basename(file))[0] + '.h5'),
                        {'feature': feature, 'predict': predict})

    logger.info(f'Saved results to {output_dir}')


def _write_datasets(output_name: str, datasets: dict[str, Any]) -> None:
    """Store each named array in the HDF5 file, replacing any dataset of the same name.

    :param output_name: Path of the HDF5 file; opened in append mode
    :param datasets: Mapping of dataset name to data, written in mapping order
    """
    import h5py

    with h5py.File(output_name, 'a') as f:
        for name, data in datasets.items():
            if name in f:
                del f[name]
            f.create_dataset(name=name, data=data)
201
-
202
-
203
def _pad_and_predict(hypermodel: Any,
                     built_model: Any,
                     feature: Feature,
                     frames_per_batch: int) -> tuple[Feature, Predict]:
    """Zero-pad feature to a whole number of batches, run prediction, and trim the result.

    :param hypermodel: Object providing batch_size, timesteps, flatten and add1ch
    :param built_model: Model whose predict() method is called on the reshaped feature
    :param feature: Feature data; padded along its first axis only
                    (assumed 3-dimensional, as the pad width has three entries)
    :param frames_per_batch: Number of frames consumed by one batch
    :return: Tuple of (padded and reshaped feature, predict trimmed to the
             original number of frames)
    """
    import numpy as np

    from sonusai.utils import reshape_inputs
    from sonusai.utils import reshape_outputs

    frames = feature.shape[0]
    # Round up to a whole number of batches. An already aligned input gets no
    # padding (previously a full extra batch of zeros was appended); at least
    # one batch is kept so the model is never called on empty input.
    batches = max(1, -(-frames // frames_per_batch))
    padding = batches * frames_per_batch - frames
    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
    feature, _ = reshape_inputs(feature=feature,
                                batch_size=hypermodel.batch_size,
                                timesteps=hypermodel.timesteps,
                                flatten=hypermodel.flatten,
                                add1ch=hypermodel.add1ch)
    predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
    predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
    # Drop the predictions that correspond to padding frames
    predict = predict[:frames, :]
    return feature, predict
224
-
225
-
226
# Script entry point: run main() and treat Ctrl-C as a clean cancellation.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is an interactive helper installed by the site module and may be
        # absent (e.g. python -S); SystemExit(0) gives the same zero exit status.
        raise SystemExit(0)