sonusai 0.15.8__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +35 -4
- sonusai/audiofe.py +237 -0
- sonusai/calc_metric_spenh.py +21 -12
- sonusai/genft.py +2 -1
- sonusai/genmixdb.py +5 -5
- sonusai/lsdb.py +2 -2
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +4 -2
- sonusai/mixture/audio.py +0 -34
- sonusai/mixture/config.py +1 -2
- sonusai/mixture/datatypes.py +1 -1
- sonusai/mixture/feature.py +75 -21
- sonusai/mixture/helpers.py +60 -30
- sonusai/mixture/log_duration_and_sizes.py +2 -2
- sonusai/mixture/mixdb.py +13 -10
- sonusai/mixture/spectral_mask.py +14 -14
- sonusai/mixture/truth_functions/data.py +1 -1
- sonusai/mixture/truth_functions/target.py +2 -2
- sonusai/mkmanifest.py +29 -2
- sonusai/onnx_predict.py +1 -1
- sonusai/plot.py +4 -4
- sonusai/post_spenh_targetf.py +8 -8
- sonusai/utils/__init__.py +8 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/audio_devices.py +41 -0
- sonusai/utils/calculate_input_shape.py +3 -4
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- sonusai/utils/reshape.py +11 -11
- sonusai/utils/wave.py +12 -5
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/METADATA +8 -19
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/RECORD +41 -54
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/WHEEL +1 -1
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/evaluate.py +0 -245
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -547
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/entry_points.txt +0 -0
sonusai/evaluate.py
DELETED
@@ -1,245 +0,0 @@
```python
"""sonusai evaluate

usage: evaluate [-hv] [-i MIXID] (-f FEATURE) (-p PREDICT) [-t PTHR] LOC

options:
    -h, --help
    -v, --verbose                   Be verbose.
    -i MIXID, --mixid MIXID         Mixture ID(s) to generate. [default: *].
    -p PREDICT, --predict PREDICT   A directory containing prediction data.
    -t PTHR, --thr PTHR             Optional prediction decision threshold(s). [default: 0].

Evaluate calculates performance metrics of neural-network models from model prediction data and genft data.

Inputs:
    LOC         A SonusAI mixture database directory.
    MIXID       A glob of mixture ID(s) to generate.
    PREDICT     A directory containing SonusAI predict HDF5 files. Contains:
                    dataset:    predict (either [frames, num_classes] or [frames, timesteps, num_classes])
    PTHR        Scalar or array of thresholds. Default 0 will select values:
                    argmax()    if mixdb indicates single-label mode (truth_mutex = true)
                    0.5         if mixdb indicates multi-label mode (truth_mutex = false)
                If PTHR = -1, optimal thresholds are calculated using precision_recall_curve() which
                optimizes F1 score.

"""
import numpy as np

from sonusai import logger
from sonusai.mixture import Feature
from sonusai.mixture import MixtureDatabase
from sonusai.mixture import Predict
from sonusai.mixture import Segsnr
from sonusai.mixture import Truth


def evaluate(mixdb: MixtureDatabase,
             truth: Truth,
             predict: Predict = None,
             segsnr: Segsnr = None,
             output_dir: str = None,
             predict_thr: float | np.ndarray = 0,
             feature: Feature = None,
             verbose: bool = False) -> None:
    from os.path import join

    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.metrics import calc_optimal_thresholds
    from sonusai.metrics import class_summary
    from sonusai.metrics import snr_summary
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.queries import get_mixids_from_snr
    from sonusai.utils import get_num_classes_from_predict
    from sonusai.utils import human_readable_size
    from sonusai.utils import reshape_outputs
    from sonusai.utils import seconds_to_hms

    update_console_handler(verbose)
    initial_log_messages('evaluate')

    if truth.shape[-1] != predict.shape[-1]:
        logger.exception(f'Number of classes in truth and predict are not equal. Exiting.')
        raise SystemExit(1)

    # truth, predict can be either [frames, num_classes] or [frames, timesteps, num_classes]
    # in binary case dim may not exist, detect this and set num_classes == 1
    timesteps = -1
    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)

    fdiff = truth.shape[0] - predict.shape[0]
    if fdiff > 0:
        # truth = truth[0:-fdiff,:]
        predict = np.concatenate((predict, np.zeros((fdiff, num_classes), dtype=np.float32)))
        logger.info(f'Truth has more feature-frames than predict, padding predict with zeros to match.')

    if fdiff < 0:
        predict = predict[0:fdiff, :]
        logger.info(f'Predict has more feature-frames than truth, trimming predict to match.')

    # Check segsnr, input is always in transform frames
    compute_segsnr = False
    if len(segsnr) > 0:
        segsnr_feature_frames = segsnr.shape[0] / (mixdb.feature_step_samples / mixdb.ft_config.R)
        if segsnr_feature_frames == truth.shape[0]:
            compute_segsnr = True
        else:
            logger.warning('segsnr length does not match truth, ignoring.')

    # Check predict_thr array or scalar and return final scalar predict_thr value
    if not mixdb.truth_mutex:
        if num_classes > 1:
            if not isinstance(predict_thr, np.ndarray):
                if predict_thr == 0:
                    # multi-label predict_thr scalar 0 force to 0.5 default
                    predict_thr = np.atleast_1d(0.5)
                else:
                    predict_thr = np.atleast_1d(predict_thr)
            else:
                if predict_thr.ndim == 1:
                    if predict_thr[0] == 0:
                        # multi-label predict_thr array scalar 0 force to 0.5 default
                        predict_thr = np.atleast_1d(0.5)
                    else:
                        # multi-label predict_thr array set to scalar = array[0]
                        predict_thr = predict_thr[0]
    else:
        # single-label mode, force argmax mode
        predict_thr = np.atleast_1d(0)

    if predict_thr == -1:
        thrpr, thrroc, _, _ = calc_optimal_thresholds(truth, predict, timesteps)
        predict_thr = np.atleast_1d(thrpr)
        predict_thr = np.maximum(predict_thr, 0.001)  # enforce lower limit
        predict_thr = np.minimum(predict_thr, 0.999)  # enforce upper limit
        predict_thr = predict_thr.round(2)

    # Summarize the mixture data
    num_mixtures = mixdb.num_mixtures
    total_samples = sum([mixture.samples for mixture in mixdb.mixtures])
    duration = total_samples / SAMPLE_RATE

    logger.info('')
    logger.info(f'Mixtures: {num_mixtures}')
    logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
    logger.info(f'truth: {human_readable_size(truth.nbytes, 1)}')
    logger.info(f'predict: {human_readable_size(predict.nbytes, 1)}')
    if compute_segsnr:
        logger.info(f'segsnr: {human_readable_size(segsnr.nbytes, 1)}')
    if feature:
        logger.info(f'feature: {human_readable_size(feature.nbytes, 1)}')

    logger.info(f'Classes: {num_classes}')
    if mixdb.truth_mutex:
        logger.info(f'Mode: Single-label / truth_mutex / softmax')
    else:
        logger.info(f'Mode: Multi-label / Binary')

    mxid_snro = get_mixids_from_snr(mixdb=mixdb)
    snrlist = list(mxid_snro.keys())
    snrlist.sort(reverse=True)
    logger.info(f'Ordered SNRs: {snrlist}\n')
    predict_thr_info = predict_thr.transpose() if isinstance(predict_thr, np.ndarray) else predict_thr
    logger.info(f'Prediction Threshold(s): {predict_thr_info}\n')

    # Top-level report over all mixtures
    macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
                                                      mixid=':',
                                                      truth_f=truth,
                                                      predict=predict,
                                                      segsnr=segsnr if compute_segsnr else None,
                                                      predict_thr=predict_thr)

    if num_classes > 1:
        logger.info(f'Metrics micro-avg per SNR over all {num_mixtures} mixtures:')
    else:
        logger.info(f'Metrics per SNR over all {num_mixtures} mixtures:')
    logger.info(microdf.round(3).to_string())
    logger.info('')
    if output_dir:
        microdf.round(3).to_csv(join(output_dir, 'snr.csv'))

    if mixdb.truth_mutex:
        macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
                                                          mixid=':',
                                                          truth_f=truth[:, 0:-1],
                                                          predict=predict[:, 0:-1],
                                                          segsnr=segsnr if compute_segsnr else None,
                                                          predict_thr=predict_thr)

        logger.info(f'Metrics micro-avg without "Other" class per SNR over all {num_mixtures} mixtures:')
        logger.info(microdf.round(3).to_string())
        logger.info('')
        if output_dir:
            microdf.round(3).to_csv(join(output_dir, 'snrwo.csv'))

    for snri in snrlist:
        mxids = mxid_snro[snri]
        classdf = class_summary(mixdb, mxids, truth, predict, predict_thr)
        logger.info(f'Metrics per class for SNR {snri} over {len(mxids)} mixtures:')
        logger.info(classdf.round(3).to_string())
        logger.info('')
        if output_dir:
            classdf.round(3).to_csv(join(output_dir, f'class_snr{snri}.csv'))


def main() -> None:
    from datetime import datetime
    from os import mkdir
    from os.path import join

    import h5py
    from docopt import docopt

    import sonusai
    from sonusai import SonusAIError
    from sonusai import create_file_handler
    from sonusai.utils import read_predict_data
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    feature_name = args['--feature']
    predict_name = args['--predict']
    predict_threshold = np.array(float(args['--thr']), dtype=np.float32)
    location = args['LOC']

    mixdb = MixtureDatabase(location)

    # create output directory
    output_dir = f'evaluate-{datetime.now():%Y%m%d}'
    try:
        mkdir(output_dir)
    except OSError as _:
        output_dir = f'evaluate-{datetime.now():%Y%m%d-%H%M%S}'
        try:
            mkdir(output_dir)
        except OSError as error:
            raise SonusAIError(f'Could not create directory, {output_dir}: {error}')

    create_file_handler(join(output_dir, 'evaluate.log'))

    with h5py.File(feature_name, 'r') as f:
        truth_f = np.array(f['truth_f'])
        segsnr = np.array(f['segsnr'])

    predict = read_predict_data(predict_name)

    evaluate(mixdb=mixdb,
             truth=truth_f,
             segsnr=segsnr,
             output_dir=output_dir,
             predict=predict,
             predict_thr=predict_threshold,
             verbose=verbose)

    logger.info(f'Wrote results to {output_dir}')


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        raise SystemExit(0)
```
sonusai/keras_onnx.py
DELETED
@@ -1,86 +0,0 @@
```python
"""sonusai keras_onnx

usage: keras_onnx [-hvr] (-m MODEL) (-w WEIGHTS) [-b BATCH] [-t TSTEPS] [-o OUTPUT]

options:
    -h, --help
    -v, --verbose                   Be verbose.
    -m MODEL, --model MODEL         Python model file.
    -w WEIGHTS, --weights WEIGHTS   Keras model weights file.
    -b BATCH, --batch BATCH         Batch size.
    -t TSTEPS, --tsteps TSTEPS      Timesteps.
    -o OUTPUT, --output OUTPUT      Output directory.

Convert a trained Keras model to ONNX.

Inputs:
    MODEL       A SonusAI Python model file with build and/or hypermodel functions.
    WEIGHTS     A Keras model weights file (or model file with weights).

Outputs:
    OUTPUT/     A directory containing:
                    <MODEL>.onnx        Model file with batch_size and timesteps equal to provided parameters
                    <MODEL>-b1.onnx     Model file with batch_size=1 and if the timesteps dimension exists it
                                        is set to 1 (useful for real-time inference applications)
                    keras_onnx.log

Results are written into subdirectory <MODEL>-<TIMESTAMP> unless OUTPUT is specified.

"""
from sonusai import logger


def main() -> None:
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weight_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    output_dir = args['--output']

    from os import makedirs
    from os.path import basename
    from os.path import join
    from os.path import splitext

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.utils import create_ts_name
    from sonusai.utils import keras_onnx

    model_tail = basename(model_name)
    model_root = splitext(model_tail)[0]

    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    if output_dir is None:
        output_dir = create_ts_name(model_root)

    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_onnx.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_onnx')

    keras_onnx(model_name, weight_name, timesteps, batch_size, output_dir)


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        exit()
```
sonusai/keras_predict.py
DELETED
@@ -1,231 +0,0 @@
```python
"""sonusai keras_predict

usage: keras_predict [-hvr] [-i MIXID] (-m MODEL) (-w KMODEL) [-b BATCH] [-t TSTEPS] INPUT ...

options:
    -h, --help
    -v, --verbose                   Be verbose.
    -i MIXID, --mixid MIXID         Mixture ID(s) to generate if input is a mixture database. [default: *].
    -m MODEL, --model MODEL         Python model file.
    -w KMODEL, --weights KMODEL     Keras model weights file.
    -b BATCH, --batch BATCH         Batch size.
    -t TSTEPS, --tsteps TSTEPS      Timesteps.
    -r, --reset                     Reset model between each file.

Run prediction on a trained Keras model defined by a SonusAI Keras Python model file using SonusAI genft or WAV data.

Inputs:
    MODEL       A SonusAI Python model file with build and/or hypermodel functions.
    KMODEL      A Keras model weights file (or model file with weights).
    INPUT       The input data must be one of the following:
                * Single WAV file or glob of WAV files
                  Using the given model, generate feature data and run prediction. A model file must be
                  provided. The MIXID is ignored.

                * directory
                  Using the given SonusAI mixture database directory, generate feature and truth data if not found.
                  Run prediction. The MIXID is required.

Outputs the following to kpredict-<TIMESTAMP> directory:
    <id>.h5
        dataset:    predict
    keras_predict.log

"""
from typing import Any

from sonusai import logger
from sonusai.mixture import Feature
from sonusai.mixture import Predict


def main() -> None:
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    model_name = args['--model']
    weights_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    reset = args['--reset']
    input_name = args['INPUT']

    from os import makedirs
    from os.path import basename
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import splitext

    import h5py
    import keras_tuner as kt
    import tensorflow as tf
    from keras import backend as kb

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import import_and_check_keras_model
    from sonusai.utils import reshape_outputs

    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name('kpredict')
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_predict')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    hypermodel = import_and_check_keras_model(model_name=model_name,
                                              weights_name=weights_name,
                                              timesteps=timesteps,
                                              batch_size=batch_size)
    built_model = hypermodel.build_model(kt.HyperParameters())

    frames_per_batch = get_frames_per_batch(hypermodel.batch_size, hypermodel.timesteps)

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info('')

    logger.info(f'Loading weights from {weights_name}')
    built_model.load_weights(weights_name)

    logger.info('')
    if len(input_name) == 1 and isdir(input_name[0]):
        input_name = input_name[0]
        logger.info(f'Load mixture database from {input_name}')
        mixdb = MixtureDatabase(input_name)

        if mixdb.feature != hypermodel.feature:
            logger.exception(f'Feature in mixture database does not match feature in model')
            raise SystemExit(1)

        mixids = mixdb.mixids_to_list(mixids)
        if reset:
            # reset mode cycles through each file one at a time
            for mixid in mixids:
                feature, _ = mixdb.mixture_ft(mixid)

                feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                                    built_model=built_model,
                                                    feature=feature,
                                                    frames_per_batch=frames_per_batch)

                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    if 'predict' in f:
                        del f['predict']
                    f.create_dataset(name='predict', data=predict)
        else:
            # Run all data at once using a data generator
            feature = KerasFromH5(mixdb=mixdb,
                                  mixids=mixids,
                                  batch_size=hypermodel.batch_size,
                                  timesteps=hypermodel.timesteps,
                                  flatten=hypermodel.flatten,
                                  add1ch=hypermodel.add1ch)

            predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
            predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)

            # Write data to separate files
            for idx, mixid in enumerate(mixids):
                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    if 'predict' in f:
                        del f['predict']
                    f.create_dataset('predict', data=predict[feature.file_indices[idx]])

        logger.info(f'Saved results to {output_dir}')
        return

    if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
        logger.exception(f'Do not know how to process input from {input_name}')
        raise SystemExit(1)

    logger.info(f'Run prediction on {len(input_name):,} WAV files')
    for file in input_name:
        # Convert WAV to feature data
        audio = read_audio(file)
        feature = get_feature_from_audio(audio=audio, feature=hypermodel.feature)

        feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                            built_model=built_model,
                                            feature=feature,
                                            frames_per_batch=frames_per_batch)

        output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
        with h5py.File(output_name, 'a') as f:
            if 'feature' in f:
                del f['feature']
            f.create_dataset(name='feature', data=feature)

            if 'predict' in f:
                del f['predict']
            f.create_dataset(name='predict', data=predict)

    logger.info(f'Saved results to {output_dir}')


def _pad_and_predict(hypermodel: Any,
                     built_model: Any,
                     feature: Feature,
                     frames_per_batch: int) -> tuple[Feature, Predict]:
    import numpy as np

    from sonusai.utils import reshape_inputs
    from sonusai.utils import reshape_outputs

    frames = feature.shape[0]
    padding = frames_per_batch - frames % frames_per_batch
    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
    feature, _ = reshape_inputs(feature=feature,
                                batch_size=hypermodel.batch_size,
                                timesteps=hypermodel.timesteps,
                                flatten=hypermodel.flatten,
                                add1ch=hypermodel.add1ch)
    predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
    predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
    predict = predict[:frames, :]
    return feature, predict


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        exit()
```