sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +36 -4
- sonusai/audiofe.py +111 -106
- sonusai/calc_metric_spenh.py +38 -22
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +15 -7
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +1 -0
- sonusai/mixture/config.py +1 -2
- sonusai/mkmanifest.py +43 -8
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict.py +16 -6
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/__init__.py +4 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -542
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/keras_train.py
DELETED
@@ -1,334 +0,0 @@
|
|
1
|
-
"""sonusai keras_train
|
2
|
-
|
3
|
-
usage: keras_train [-hgv] (-m MODEL) (-l VLOC) [-w KMODEL] [-e EPOCHS] [-b BATCH] [-t TSTEPS] [-p ESP] TLOC
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-m MODEL, --model MODEL Model Python file with build and/or hypermodel functions.
|
9
|
-
-l VLOC, --vloc VLOC Location of SonusAI mixture database to use for validation.
|
10
|
-
-w KMODEL, --weights KMODEL Keras model weights file.
|
11
|
-
-e EPOCHS, --epochs EPOCHS Number of epochs to use in training. [default: 8].
|
12
|
-
-b BATCH, --batch BATCH Batch size.
|
13
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps.
|
14
|
-
-p ESP, --patience ESP Early stopping patience.
|
15
|
-
-g, --loss-batch-log Enable per-batch loss log. [default: False]
|
16
|
-
|
17
|
-
Use Keras to train a model defined by a Python definition file and SonusAI genft data.
|
18
|
-
|
19
|
-
Inputs:
|
20
|
-
TLOC A SonusAI mixture database directory to use for training data.
|
21
|
-
VLOC A SonusAI mixture database directory to use for validation data.
|
22
|
-
|
23
|
-
Results are written into subdirectory <MODEL>-<TIMESTAMP>.
|
24
|
-
Per-batch loss history, if enabled, is written to <basename>-history-lossb.npy
|
25
|
-
|
26
|
-
"""
|
27
|
-
import tensorflow as tf
|
28
|
-
|
29
|
-
from sonusai import logger
|
30
|
-
|
31
|
-
|
32
|
-
class LossBatchHistory(tf.keras.callbacks.Callback):
    """Keras callback that collects the loss value reported at the end of every batch.

    After training starts, ``history`` is a dict with a single key ``'loss'``
    holding one entry per batch, in the order the batches finished.
    """

    def __init__(self):
        super().__init__()
        # Stays None until on_train_begin() runs.
        self.history = None

    def on_train_begin(self, logs=None):
        # Start from an empty record each time training begins.
        self.history = dict(loss=[])

    def on_batch_end(self, batch, logs=None):
        # A missing logs dict is treated as empty, so None is recorded for that batch.
        batch_logs = {} if logs is None else logs
        self.history['loss'].append(batch_logs.get('loss'))
|
44
|
-
|
45
|
-
|
46
|
-
class SonusAIModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that also stamps SonusAI metadata onto the saved HDF5 file.

    Takes the same arguments as ``tf.keras.callbacks.ModelCheckpoint`` plus two
    optional keyword arguments:

    feature:      value stored in the ``sonusai_feature`` file attribute.
    num_classes:  value stored, as a string, in the ``sonusai_num_classes``
                  file attribute.
    """

    def __init__(self,
                 filepath,
                 monitor: str = "val_loss",
                 verbose: int = 0,
                 save_best_only: bool = False,
                 save_weights_only: bool = False,
                 mode: str = "auto",
                 save_freq="epoch",
                 options=None,
                 initial_value_threshold=None,
                 **kwargs):
        # Take the SonusAI-specific keywords out before forwarding so the parent
        # constructor only ever sees arguments it knows about.
        feature = kwargs.pop('feature', None)
        num_classes = kwargs.pop('num_classes', None)

        # Forward by keyword so a reordering of the parent signature cannot
        # silently bind a value to the wrong parameter.
        super().__init__(filepath,
                         monitor=monitor,
                         verbose=verbose,
                         save_best_only=save_best_only,
                         save_weights_only=save_weights_only,
                         mode=mode,
                         save_freq=save_freq,
                         options=options,
                         initial_value_threshold=initial_value_threshold,
                         **kwargs)

        self.feature = feature
        self.num_classes = num_classes

    def _save_model(self, epoch, batch, logs):
        """Let the parent save the checkpoint, then add the SonusAI attributes to it."""
        from os.path import isfile

        import h5py

        super()._save_model(epoch, batch, logs)

        if self.feature is None and self.num_classes is None:
            return

        # Only annotate a file that is actually on disk. Opening in 'a' mode
        # would otherwise create an empty HDF5 stub whenever the parent did not
        # write a checkpoint (presumably e.g. save_best_only without improvement),
        # and a later load_weights() on that stub would fail.
        if not isfile(self.filepath):
            return

        with h5py.File(self.filepath, 'a') as f:
            if self.feature is not None:
                f.attrs['sonusai_feature'] = self.feature
            if self.num_classes is not None:
                f.attrs['sonusai_num_classes'] = str(self.num_classes)
|
81
|
-
|
82
|
-
|
83
|
-
def main() -> None:
    """Train a Keras model on SonusAI mixture data (command-line entry point).

    Parses the module usage string with docopt, builds and compiles the
    ``MyHyperModel`` defined in the MODEL file, trains it on the TLOC mixture
    database while validating on VLOC, and writes the log, training history,
    checkpoint weights, final ``.h5`` model and per-mixture validation
    predictions into the output directory returned by ``create_ts_name``.

    Raises:
        SystemExit: if building or compiling the model fails.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weights_name = args['--weights']
    v_name = args['--vloc']
    epochs = int(args['--epochs'])
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    esp = args['--patience']
    loss_batch_log = args['--loss-batch-log']
    t_name = args['TLOC']

    import warnings
    from os import makedirs
    from os.path import basename
    from os.path import exists
    from os.path import join
    from os.path import splitext

    import h5py
    import keras_tuner as kt
    import numpy as np

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        from keras import backend as kb
        from keras.callbacks import EarlyStopping

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import check_keras_overrides
    from sonusai.utils import create_ts_name
    from sonusai.utils import import_keras_model
    from sonusai.utils import stratified_shuffle_split_mixid
    from sonusai.utils import reshape_outputs
    from sonusai.utils import get_frames_per_batch

    model_base = basename(model_name)
    model_root = splitext(model_base)[0]

    # Optional overrides arrive as strings (or None when not given).
    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name(model_root)
    makedirs(output_dir, exist_ok=True)
    base_name = join(output_dir, model_root)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_train.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_train')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    t_mixdb = MixtureDatabase(t_name)
    logger.info(f'Training: found {len(t_mixdb.mixtures)} mixtures with {t_mixdb.num_classes} classes from {t_name}')

    v_mixdb = MixtureDatabase(v_name)
    logger.info(f'Validation: found {len(v_mixdb.mixtures)} mixtures with {v_mixdb.num_classes} classes from {v_name}')

    # Import model definition file
    logger.info(f'Importing {model_base}')
    model = import_keras_model(model_name)

    # Check overrides
    timesteps = check_keras_overrides(model, t_mixdb.feature, t_mixdb.num_classes, timesteps, batch_size)
    # Calculate batches per epoch, use ceiling as last batch is zero extended
    frames_per_batch = get_frames_per_batch(batch_size, timesteps)
    batches_per_epoch = int(np.ceil(t_mixdb.total_feature_frames('*') / frames_per_batch))

    logger.info('Building and compiling model')
    try:
        hypermodel = model.MyHyperModel(feature=t_mixdb.feature,
                                        num_classes=t_mixdb.num_classes,
                                        timesteps=timesteps,
                                        batch_size=batch_size)
        built_model = hypermodel.build_model(kt.HyperParameters())
        built_model = hypermodel.compile_default(built_model, batches_per_epoch)
    except Exception as e:
        # Top-level boundary: log the traceback, then exit with a failure status.
        logger.exception(f'Error: build_model() in {model_base} failed: {e}')
        raise SystemExit(1) from e

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'lossf {hypermodel.lossf}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info(f'optimizer {built_model.optimizer.get_config()}')
    logger.info('')

    t_mixid = t_mixdb.mixids_to_list()
    v_mixid = v_mixdb.mixids_to_list()

    # Developer toggle: stratification of the training mixtures is currently disabled.
    stratify = False
    if stratify:
        logger.info('Stratifying training data')
        t_mixid, _, _, _ = stratified_shuffle_split_mixid(t_mixdb, vsplit=0)

    # Use SonusAI DataGenerator to create validation feature/truth on the fly
    v_datagen = KerasFromH5(mixdb=v_mixdb,
                            mixids=v_mixid,
                            batch_size=hypermodel.batch_size,
                            timesteps=hypermodel.timesteps,
                            flatten=hypermodel.flatten,
                            add1ch=hypermodel.add1ch,
                            shuffle=False)

    # NOTE: class weighting is not applied; training runs with uniform class weights.

    # Use SonusAI DataGenerator to create training feature/truth on the fly
    t_datagen = KerasFromH5(mixdb=t_mixdb,
                            mixids=t_mixid,
                            batch_size=hypermodel.batch_size,
                            timesteps=hypermodel.timesteps,
                            flatten=hypermodel.flatten,
                            add1ch=hypermodel.add1ch,
                            shuffle=True)

    # TODO: If hypermodel.es exists, then use it; otherwise use default here
    patience = 8 if esp is None else int(esp)
    es = EarlyStopping(monitor='val_loss',
                       mode='min',
                       verbose=1,
                       patience=patience)

    # Keep the checkpoint path in one place so it can be reloaded below without
    # guessing which file in the output directory is the checkpoint.
    ckpt_name = base_name + '-ckpt-weights.h5'
    ckpt_callback = SonusAIModelCheckpoint(filepath=ckpt_name,
                                           save_weights_only=True,
                                           monitor='val_loss',
                                           mode='min',
                                           save_best_only=True,
                                           feature=hypermodel.feature,
                                           num_classes=hypermodel.num_classes)

    csv_logger = tf.keras.callbacks.CSVLogger(base_name + '-history.csv')
    callbacks = [es, ckpt_callback, csv_logger]
    loss_batchlogger = None
    if loss_batch_log:
        loss_batchlogger = LossBatchHistory()
        callbacks.append(loss_batchlogger)
        logger.info('Adding per batch loss logging to training')

    if weights_name is not None:
        logger.info(f'Loading weights from {weights_name}')
        built_model.load_weights(weights_name)

    logger.info('')
    logger.info(f'Training with no class weighting and early stopping patience = {es.patience}')
    logger.info(f' training mixtures {len(t_mixid)}')
    logger.info(f' validation mixtures {len(v_mixid)}')
    logger.info('')

    history = built_model.fit(t_datagen,
                              batch_size=hypermodel.batch_size,
                              epochs=epochs,
                              validation_data=v_datagen,
                              shuffle=False,
                              callbacks=callbacks)

    # Save history into numpy file
    history_name = base_name + '-history'
    np.save(history_name, history.history)
    # Note: Reload with history=np.load(history_name, allow_pickle='TRUE').item()
    logger.info(f'Saved training history to numpy file {history_name}.npy')
    if loss_batchlogger is not None:
        his_batch_loss_name = base_name + '-history-lossb.npy'
        np.save(his_batch_loss_name, loss_batchlogger.history)
        logger.info(f'Saved per-batch loss history to numpy file {his_batch_loss_name}')

    # Load the checkpoint weights for prediction and model save. Test the exact
    # checkpoint path rather than scanning for any file containing "ckpt": the
    # history files share the model root name, so a model whose name contains
    # "ckpt" would make such a scan pick the wrong file.
    if exists(ckpt_name):
        logger.info('Using best checkpoint for prediction and model exports')
        built_model.load_weights(ckpt_name)
    else:
        logger.info('Using last epoch for prediction and model exports')

    # save for later model export(s)
    weight_name = base_name + '.h5'
    built_model.save(weight_name)
    with h5py.File(weight_name, 'a') as f:
        f.attrs['sonusai_feature'] = hypermodel.feature
        f.attrs['sonusai_num_classes'] = str(hypermodel.num_classes)
    logger.info(f'Saved trained model to {weight_name}')

    # Compute prediction metrics on validation data using the best checkpoint
    v_predict = built_model.predict(v_datagen, batch_size=hypermodel.batch_size, verbose=1)
    v_predict, _ = reshape_outputs(predict=v_predict, timesteps=hypermodel.timesteps)

    # Write data to separate files
    v_predict_dir = base_name + '-valpredict'
    makedirs(v_predict_dir, exist_ok=True)
    for idx, mixid in enumerate(v_mixid):
        output_name = join(v_predict_dir, v_mixdb.mixtures[mixid].name)
        indices = v_datagen.file_indices[idx]
        frames = indices.stop - indices.start
        data = v_predict[indices]
        # The predict operation may produce less data because timesteps and batches
        # may not divide evenly. Only write data if the full range exists.
        if data.shape[0] == frames:
            with h5py.File(output_name, 'a') as f:
                # Replace any existing dataset so reruns do not fail on a name clash.
                if 'predict' in f:
                    del f['predict']
                f.create_dataset('predict', data=data)

    logger.info(f'Wrote validation predict data to {v_predict_dir}')
|
327
|
-
|
328
|
-
|
329
|
-
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # Raise SystemExit directly: the bare exit() helper is injected by the
        # site module for interactive use and is not meant for programs.
        raise SystemExit(0)
|
sonusai/torchl_onnx.py
DELETED
@@ -1,216 +0,0 @@
|
|
1
|
-
"""sonusai torchl_onnx
|
2
|
-
|
3
|
-
usage: torchl_onnx [-hv] [-b BATCH] [-t TSTEPS] [-o OUTPUT] MODEL CKPT
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose
|
8
|
-
-b BATCH, --batch BATCH Batch size [default: 1]
|
9
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps [default: 1]
|
10
|
-
-o OUTPUT, --output OUTPUT Output directory.
|
11
|
-
|
12
|
-
Convert a trained Pytorch Lightning model to ONNX. The model is specified as an
|
13
|
-
sctl_*.py model file (sctl: sonusai custom torch lightning) and a checkpoint file
|
14
|
-
for loading weights.
|
15
|
-
|
16
|
-
Inputs:
|
17
|
-
MODEL SonusAI Python custom model file.
|
18
|
-
CKPT A Pytorch Lightning checkpoint file
|
19
|
-
BATCH Batch size used in onnx conversion, overrides value in model ckpt. Defaults to 1.
|
20
|
-
TSTEPS Timestep dimension size used in onnx conversion, overrides value in model ckpt if
|
21
|
-
the model has a timestep dimension. Else it is ignored.
|
22
|
-
|
23
|
-
Outputs:
|
24
|
-
OUTPUT/ A directory containing:
|
25
|
-
<CKPT>.onnx Model file with batch_size and timesteps equal to provided parameters
|
26
|
-
<CKPT>-b1.onnx Model file with batch_size=1 and if the timesteps dimension exists it
|
27
|
-
is set to 1 (useful for real-time inference applications)
|
28
|
-
torchl_onnx.log
|
29
|
-
|
30
|
-
Results are written into the directory containing CKPT unless OUTPUT is specified.
|
31
|
-
|
32
|
-
"""
|
33
|
-
from sonusai import logger
|
34
|
-
|
35
|
-
|
36
|
-
def main() -> None:
    """Convert a PyTorch Lightning checkpoint to ONNX (command-line entry point).

    Parses the module usage string with docopt, imports the MODEL definition
    file, loads the CKPT checkpoint, rebuilds ``MyHyperModel`` (from the
    checkpoint hyper-parameters when present, otherwise from its defaults) with
    the requested batch size and timesteps, loads the checkpoint weights and
    exports the model to ``<CKPT root>.onnx`` in the output directory.

    Raises:
        SystemExit: if the model file cannot be imported, the checkpoint cannot
            be loaded, or building the model inside a guarded section fails.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    model_path = args['MODEL']
    ckpt_path = args['CKPT']
    output_dir = args['--output']

    from os import makedirs
    from os.path import basename, splitext
    from sonusai.utils import import_keras_model

    # Import model definition file first to check
    model_base = basename(model_path)
    model_root = splitext(model_base)[0]
    logger.info(f'Importing model from {model_base}')
    try:
        litemodule = import_keras_model(model_path)  # note works for pytorch lightning as well as keras
    except Exception as e:
        logger.exception(f'Error: could not import model from {model_path}: {e}')
        raise SystemExit(1) from e

    # Load checkpoint first to get hparams if available
    from torch import load as load
    ckpt_base = basename(ckpt_path)
    ckpt_root = splitext(ckpt_base)[0]
    logger.info(f'Loading checkpoint from {ckpt_base}')
    try:
        # NOTE(security): torch.load unpickles the file; only use checkpoints
        # from a trusted source.
        checkpoint = load(ckpt_path, map_location=lambda storage, loc: storage)
    except Exception as e:
        logger.exception(f'Error: could not load checkpoint from {ckpt_path}: {e}')
        raise SystemExit(1) from e

    from os.path import join, dirname, exists
    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from torch import randn

    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size != 1:
            batch_size = 1
            logger.info('For now prediction only supports batch_size = 1, forcing it to 1 now')

    if timesteps is not None:
        timesteps = int(timesteps)

    if output_dir is None:
        # Default to writing next to the checkpoint file.
        output_dir = dirname(ckpt_path)
    else:
        makedirs(output_dir, exist_ok=True)

    ofname = join(output_dir, ckpt_root + '.onnx')
    # Try the plain name first; if it is already taken, append today's date.
    if exists(ofname):
        from datetime import datetime
        ts = datetime.now()
        ofname = join(output_dir, ckpt_root + '-' + ts.strftime('%Y%m%d') + '.onnx')
    ofname_root = splitext(ofname)[0]

    # Setup logging file
    create_file_handler(ofname_root + '-onnx.log')
    update_console_handler(verbose)
    initial_log_messages('torchl_onnx')
    logger.info(f'Imported model from {model_base}')
    logger.info(f'Loaded checkpoint from {ckpt_base}')

    if 'hyper_parameters' in checkpoint:
        hparams = checkpoint['hyper_parameters']
        # Use .get() so a checkpoint without 'hparams_name' does not abort the conversion.
        logger.info(f'Found hyper-params on checkpoint named {checkpoint.get("hparams_name")} '
                    f'with {len(hparams)} total hparams.')
        if batch_size is None:
            batch_size = hparams['batch_size']  # inherit
        elif hparams['batch_size'] != batch_size:
            logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
            hparams["batch_size"] = batch_size

        if timesteps is None:
            timesteps = hparams['timesteps']  # inherit
        else:
            if hparams['timesteps'] == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if hparams['timesteps'] != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override of 0, using model default.')
                timesteps = hparams['timesteps']

            if hparams['timesteps'] != timesteps:
                logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
                hparams['timesteps'] = timesteps

        logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
        try:
            model = litemodule.MyHyperModel(**hparams)  # use hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1) from e
    else:
        logger.info('Warning: found checkpoint with no hyper-parameters, building model with defaults')
        try:
            # Throwaway instance, built only to read the model's default sizes.
            tmp = litemodule.MyHyperModel()  # use default hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1) from e

        if batch_size is not None:
            if tmp.batch_size != batch_size:
                logger.info(f'Overriding batch_size: default = {tmp.batch_size}; specified = {batch_size}.')
        else:
            batch_size = tmp.batch_size  # inherit

        if timesteps is not None:
            if tmp.timesteps == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if tmp.timesteps != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override.')
                timesteps = tmp.timesteps

            if tmp.timesteps != timesteps:
                logger.info(f'Overriding timesteps: default = {tmp.timesteps}; specified = {timesteps}.')
        else:
            timesteps = tmp.timesteps

        logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
        model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)

    logger.info('')
    logger.info('')
    logger.info(f'feature {model.hparams.feature}')
    logger.info(f'num_classes {model.num_classes}')
    logger.info(f'batch_size {model.hparams.batch_size}')
    logger.info(f'timesteps {model.hparams.timesteps}')
    logger.info(f'flatten {model.flatten}')
    logger.info(f'add1ch {model.add1ch}')
    logger.info(f'truth_mutex {model.truth_mutex}')
    logger.info(f'input_shape {model.input_shape}')
    logger.info('')
    logger.info(f'Loading weights from {ckpt_base}')
    # model = model.load_from_checkpoint(ckpt_path) # weights only, has problems - needs investigation
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    # Build a new list for the sample shape; inserting into model.input_shape
    # directly would permanently alter the model's own attribute.
    insample_shape = [batch_size, *model.input_shape]
    input_sample = randn(insample_shape)
    logger.info('Creating onnx model ...')
    for m in model.modules():
        if 'instancenorm' in m.__class__.__name__.lower():
            logger.info(f'Forcing train=false for instancenorm instance {m}, {m.__class__.__name__.lower()}')
            m.train(False)
            # m.track_running_stats=True # has problems
    model.to_onnx(file_path=ofname, input_sample=input_sample, export_params=True)
|
209
|
-
|
210
|
-
|
211
|
-
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # Raise SystemExit directly: the bare exit() helper is injected by the
        # site module for interactive use and is not meant for programs.
        raise SystemExit(0)
|