sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sonusai/__init__.py +36 -4
  2. sonusai/audiofe.py +111 -106
  3. sonusai/calc_metric_spenh.py +38 -22
  4. sonusai/genft.py +15 -6
  5. sonusai/genmix.py +14 -6
  6. sonusai/genmixdb.py +15 -7
  7. sonusai/gentcst.py +13 -6
  8. sonusai/lsdb.py +15 -5
  9. sonusai/main.py +58 -61
  10. sonusai/mixture/__init__.py +1 -0
  11. sonusai/mixture/config.py +1 -2
  12. sonusai/mkmanifest.py +43 -8
  13. sonusai/mkwav.py +15 -6
  14. sonusai/onnx_predict.py +16 -6
  15. sonusai/plot.py +16 -6
  16. sonusai/post_spenh_targetf.py +13 -6
  17. sonusai/summarize_metric_spenh.py +71 -0
  18. sonusai/tplot.py +14 -6
  19. sonusai/utils/__init__.py +4 -7
  20. sonusai/utils/asl_p56.py +3 -3
  21. sonusai/utils/asr.py +35 -8
  22. sonusai/utils/asr_functions/__init__.py +0 -5
  23. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  24. sonusai/utils/asr_manifest_functions/__init__.py +1 -0
  25. sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
  26. sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
  27. sonusai/utils/model_utils.py +30 -0
  28. sonusai/utils/onnx_utils.py +19 -45
  29. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
  30. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
  31. sonusai/data_generator/__init__.py +0 -5
  32. sonusai/data_generator/dataset_from_mixdb.py +0 -143
  33. sonusai/data_generator/keras_from_mixdb.py +0 -169
  34. sonusai/data_generator/torch_from_mixdb.py +0 -122
  35. sonusai/keras_onnx.py +0 -86
  36. sonusai/keras_predict.py +0 -231
  37. sonusai/keras_train.py +0 -334
  38. sonusai/torchl_onnx.py +0 -216
  39. sonusai/torchl_predict.py +0 -542
  40. sonusai/torchl_train.py +0 -223
  41. sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
  42. sonusai/utils/asr_functions/data.py +0 -16
  43. sonusai/utils/asr_functions/deepgram.py +0 -97
  44. sonusai/utils/asr_functions/fastwhisper.py +0 -90
  45. sonusai/utils/asr_functions/google.py +0 -95
  46. sonusai/utils/asr_functions/whisper.py +0 -49
  47. sonusai/utils/keras_utils.py +0 -226
  48. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
  49. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/keras_train.py DELETED
@@ -1,334 +0,0 @@
1
- """sonusai keras_train
2
-
3
- usage: keras_train [-hgv] (-m MODEL) (-l VLOC) [-w KMODEL] [-e EPOCHS] [-b BATCH] [-t TSTEPS] [-p ESP] TLOC
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -m MODEL, --model MODEL Model Python file with build and/or hypermodel functions.
9
- -l VLOC, --vloc VLOC Location of SonusAI mixture database to use for validation.
10
- -w KMODEL, --weights KMODEL Keras model weights file.
11
- -e EPOCHS, --epochs EPOCHS Number of epochs to use in training. [default: 8].
12
- -b BATCH, --batch BATCH Batch size.
13
- -t TSTEPS, --tsteps TSTEPS Timesteps.
14
- -p ESP, --patience ESP Early stopping patience.
15
- -g, --loss-batch-log Enable per-batch loss log. [default: False]
16
-
17
- Use Keras to train a model defined by a Python definition file and SonusAI genft data.
18
-
19
- Inputs:
20
- TLOC A SonusAI mixture database directory to use for training data.
21
- VLOC A SonusAI mixture database directory to use for validation data.
22
-
23
- Results are written into subdirectory <MODEL>-<TIMESTAMP>.
24
- Per-batch loss history, if enabled, is written to <basename>-history-lossb.npy
25
-
26
- """
27
- import tensorflow as tf
28
-
29
- from sonusai import logger
30
-
31
-
32
class LossBatchHistory(tf.keras.callbacks.Callback):
    """Keras callback that records the training loss reported after every batch.

    ``self.history`` is ``None`` until training starts; from then on it is a dict
    whose ``'loss'`` entry is a list with one value per finished batch.  The list
    is reset at the beginning of every training run.
    """

    def __init__(self):
        super().__init__()
        self.history = None  # filled in by on_train_begin

    def on_train_begin(self, logs=None):
        # Start a fresh record for this training run.
        self.history = dict(loss=[])

    def on_batch_end(self, batch, logs=None):
        # Keras may call this with logs=None; record None as the loss in that case.
        batch_logs = {} if logs is None else logs
        self.history['loss'].append(batch_logs.get('loss'))
45
-
46
class SonusAIModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that also tags the saved HDF5 file with SonusAI metadata.

    After every save performed by the parent class, the attributes
    ``sonusai_feature`` and ``sonusai_num_classes`` are written into the file
    at ``self.filepath``.

    Extra keyword arguments (consumed here, not forwarded to Keras):
        feature: SonusAI feature name, stored as ``sonusai_feature``.
        num_classes: number of classes, stored as a string in ``sonusai_num_classes``.

    All other arguments have the same meaning as for ``ModelCheckpoint``.
    """

    def __init__(self,
                 filepath,
                 monitor: str = "val_loss",
                 verbose: int = 0,
                 save_best_only: bool = False,
                 save_weights_only: bool = False,
                 mode: str = "auto",
                 save_freq="epoch",
                 options=None,
                 initial_value_threshold=None,
                 **kwargs):
        # Take the SonusAI metadata out of kwargs so it never reaches the parent
        # constructor; some Keras versions reject unknown keyword arguments.
        feature = kwargs.pop('feature', None)
        num_classes = kwargs.pop('num_classes', None)

        # Forward by keyword: the positional order of ModelCheckpoint's parameters
        # differs between Keras versions (e.g. Keras 3 has no `options` parameter),
        # so only pass `options` when the caller actually supplied one.
        if options is not None:
            kwargs['options'] = options
        super().__init__(filepath,
                         monitor=monitor,
                         verbose=verbose,
                         save_best_only=save_best_only,
                         save_weights_only=save_weights_only,
                         mode=mode,
                         save_freq=save_freq,
                         initial_value_threshold=initial_value_threshold,
                         **kwargs)
        self.feature = feature
        self.num_classes = num_classes

    def _save_model(self, epoch, batch, logs):
        """Save via the parent class, then add SonusAI attributes to the saved file."""
        from os.path import isfile

        import h5py

        super()._save_model(epoch, batch, logs)

        if self.feature is None and self.num_classes is None:
            return

        # Only annotate a file the parent actually wrote; opening a missing path in
        # 'a' mode would otherwise create a stray, empty HDF5 file.
        if not isfile(self.filepath):
            return

        with h5py.File(self.filepath, 'a') as f:
            if self.feature is not None:
                f.attrs['sonusai_feature'] = self.feature
            if self.num_classes is not None:
                f.attrs['sonusai_num_classes'] = str(self.num_classes)
81
-
82
-
83
def main() -> None:
    """Entry point for ``keras_train``: train a Keras model on SonusAI mixture databases.

    Steps:
      1. Parse the docopt usage in the module docstring.
      2. Load the training (TLOC) and validation (VLOC) mixture databases.
      3. Import the model file, then build and compile ``MyHyperModel``.
      4. Train with early stopping, best-checkpoint saving, CSV history logging and
         (optionally) per-batch loss logging.
      5. Reload the best checkpoint if one was written, save ``<base>.h5`` tagged
         with SonusAI metadata, and write per-mixture validation predictions.

    All results go into the timestamped directory created from the model name.
    Exits with status 1 if building/compiling the model fails.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weights_name = args['--weights']
    v_name = args['--vloc']
    epochs = int(args['--epochs'])
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    esp = args['--patience']
    loss_batch_log = args['--loss-batch-log']
    t_name = args['TLOC']

    import warnings
    from os import makedirs
    from os.path import basename
    from os.path import isfile
    from os.path import join
    from os.path import splitext

    import h5py
    import keras_tuner as kt
    import numpy as np

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        from keras import backend as kb
        from keras.callbacks import EarlyStopping

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import check_keras_overrides
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import import_keras_model
    from sonusai.utils import reshape_outputs
    from sonusai.utils import stratified_shuffle_split_mixid

    model_base = basename(model_name)
    model_root = splitext(model_base)[0]

    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name(model_root)
    makedirs(output_dir, exist_ok=True)
    base_name = join(output_dir, model_root)

    # Set up logging to <output_dir>/keras_train.log in addition to the console
    create_file_handler(join(output_dir, 'keras_train.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_train')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    t_mixdb = MixtureDatabase(t_name)
    logger.info(f'Training: found {len(t_mixdb.mixtures)} mixtures with {t_mixdb.num_classes} classes from {t_name}')

    v_mixdb = MixtureDatabase(v_name)
    logger.info(f'Validation: found {len(v_mixdb.mixtures)} mixtures with {v_mixdb.num_classes} classes from {v_name}')

    # Import model definition file
    logger.info(f'Importing {model_base}')
    model = import_keras_model(model_name)

    # Check overrides; the helper returns the timesteps value to use from here on
    timesteps = check_keras_overrides(model, t_mixdb.feature, t_mixdb.num_classes, timesteps, batch_size)
    # Calculate batches per epoch, use ceiling as last batch is zero extended
    frames_per_batch = get_frames_per_batch(batch_size, timesteps)
    batches_per_epoch = int(np.ceil(t_mixdb.total_feature_frames('*') / frames_per_batch))

    logger.info('Building and compiling model')
    try:
        hypermodel = model.MyHyperModel(feature=t_mixdb.feature,
                                        num_classes=t_mixdb.num_classes,
                                        timesteps=timesteps,
                                        batch_size=batch_size)
        built_model = hypermodel.build_model(kt.HyperParameters())
        built_model = hypermodel.compile_default(built_model, batches_per_epoch)
    except Exception as e:
        logger.exception(f'Error: build_model() in {model_base} failed: {e}')
        raise SystemExit(1)

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'lossf {hypermodel.lossf}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info(f'optimizer {built_model.optimizer.get_config()}')
    logger.info('')

    t_mixid = t_mixdb.mixids_to_list()
    v_mixid = v_mixdb.mixids_to_list()

    # Developer toggle: stratified shuffling of the training mixtures (currently disabled)
    stratify = False
    if stratify:
        logger.info('Stratifying training data')
        t_mixid, _, _, _ = stratified_shuffle_split_mixid(t_mixdb, vsplit=0)

    # Use SonusAI DataGenerator to create validation feature/truth on the fly
    v_datagen = KerasFromH5(mixdb=v_mixdb,
                            mixids=v_mixid,
                            batch_size=hypermodel.batch_size,
                            timesteps=hypermodel.timesteps,
                            flatten=hypermodel.flatten,
                            add1ch=hypermodel.add1ch,
                            shuffle=False)

    # Use SonusAI DataGenerator to create training feature/truth on the fly
    t_datagen = KerasFromH5(mixdb=t_mixdb,
                            mixids=t_mixid,
                            batch_size=hypermodel.batch_size,
                            timesteps=hypermodel.timesteps,
                            flatten=hypermodel.flatten,
                            add1ch=hypermodel.add1ch,
                            shuffle=True)

    # TODO: If hypermodel.es exists, then use it; otherwise use default here
    patience = 8 if esp is None else int(esp)
    es = EarlyStopping(monitor='val_loss',
                       mode='min',
                       verbose=1,
                       patience=patience)

    # Best-only weights checkpoint; reloaded after training for prediction and export
    ckpt_name = base_name + '-ckpt-weights.h5'
    ckpt_callback = SonusAIModelCheckpoint(filepath=ckpt_name,
                                           save_weights_only=True,
                                           monitor='val_loss',
                                           mode='min',
                                           save_best_only=True,
                                           feature=hypermodel.feature,
                                           num_classes=hypermodel.num_classes)

    csv_logger = tf.keras.callbacks.CSVLogger(base_name + '-history.csv')
    callbacks = [es, ckpt_callback, csv_logger]
    loss_batchlogger = None
    if loss_batch_log:
        loss_batchlogger = LossBatchHistory()
        callbacks.append(loss_batchlogger)
        logger.info('Adding per batch loss logging to training')

    if weights_name is not None:
        logger.info(f'Loading weights from {weights_name}')
        built_model.load_weights(weights_name)

    logger.info('')
    logger.info(f'Training with no class weighting and early stopping patience = {es.patience}')
    logger.info(f'  training mixtures    {len(t_mixid)}')
    logger.info(f'  validation mixtures  {len(v_mixid)}')
    logger.info('')

    # batch_size is not passed: the generators already yield complete batches, and
    # Keras documents that batch_size must not be given for generator/Sequence input.
    history = built_model.fit(t_datagen,
                              epochs=epochs,
                              validation_data=v_datagen,
                              shuffle=False,
                              callbacks=callbacks)

    # Save history into numpy file
    history_name = base_name + '-history'
    np.save(history_name, history.history)
    # Note: Reload with history=np.load(history_name, allow_pickle='TRUE').item()
    logger.info(f'Saved training history to numpy file {history_name}.npy')
    if loss_batchlogger is not None:
        his_batch_loss_name = base_name + '-history-lossb.npy'
        np.save(his_batch_loss_name, loss_batchlogger.history)
        logger.info(f'Saved per-batch loss history to numpy file {his_batch_loss_name}')

    # Use the known checkpoint path instead of scanning output_dir for any file name
    # containing "ckpt": such a scan also matches the history/log files whenever the
    # model name itself contains "ckpt".
    if isfile(ckpt_name):
        logger.info('Using best checkpoint for prediction and model exports')
        built_model.load_weights(ckpt_name)
    else:
        logger.info('Using last epoch for prediction and model exports')

    # save for later model export(s)
    weight_name = base_name + '.h5'
    built_model.save(weight_name)
    with h5py.File(weight_name, 'a') as f:
        f.attrs['sonusai_feature'] = hypermodel.feature
        f.attrs['sonusai_num_classes'] = str(hypermodel.num_classes)
    logger.info(f'Saved trained model to {weight_name}')

    # Compute prediction metrics on validation data using the best checkpoint
    v_predict = built_model.predict(v_datagen, verbose=1)
    v_predict, _ = reshape_outputs(predict=v_predict, timesteps=hypermodel.timesteps)

    # Write data to separate files, one per validation mixture
    v_predict_dir = base_name + '-valpredict'
    makedirs(v_predict_dir, exist_ok=True)
    for idx, mixid in enumerate(v_mixid):
        output_name = join(v_predict_dir, v_mixdb.mixtures[mixid].name)
        indices = v_datagen.file_indices[idx]
        frames = indices.stop - indices.start
        data = v_predict[indices]
        # The predict operation may produce less data because timesteps and batches
        # may not divide evenly; only write data when the mixture is complete
        if data.shape[0] == frames:
            with h5py.File(output_name, 'a') as f:
                if 'predict' in f:
                    del f['predict']
                f.create_dataset('predict', data=data)

    logger.info(f'Wrote validation predict data to {v_predict_dir}')
327
-
328
-
329
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # SystemExit instead of the site-provided exit(), which is intended for the
        # interactive interpreter and is unavailable when Python runs with -S.
        raise SystemExit()
sonusai/torchl_onnx.py DELETED
@@ -1,216 +0,0 @@
1
"""sonusai torchl_onnx

usage: torchl_onnx [-hv] [-b BATCH] [-t TSTEPS] [-o OUTPUT] MODEL CKPT

options:
    -h, --help
    -v, --verbose                   Be verbose
    -b BATCH, --batch BATCH         Batch size [default: 1]
    -t TSTEPS, --tsteps TSTEPS      Timesteps [default: 1]
    -o OUTPUT, --output OUTPUT      Output directory.

Convert a trained Pytorch Lightning model to ONNX. The model is specified as an
sctl_*.py model file (sctl: sonusai custom torch lightning) and a checkpoint file
for loading weights.

Inputs:
    MODEL       SonusAI Python custom model file.
    CKPT        A Pytorch Lightning checkpoint file.
    BATCH       Batch size used in onnx conversion, overrides value in model ckpt.
                Currently any value is forced to 1.
    TSTEPS      Timestep dimension size used in onnx conversion, overrides value in model ckpt if
                the model has a timestep dimension. Else it is ignored.

Outputs:
    OUTPUT/                 Defaults to the directory containing CKPT. Contains:
        <CKPT>.onnx         Model file with batch_size and timesteps equal to the provided
                            parameters. If that file already exists, a timestamp suffix is
                            added to the name instead of overwriting it.
        <CKPT>-onnx.log     Log file, named after the .onnx file (including any timestamp suffix).

"""
- """
33
- from sonusai import logger
34
-
35
-
36
def main() -> None:
    """Entry point for ``torchl_onnx``: export a PyTorch Lightning checkpoint to ONNX.

    Steps:
      1. Import the model definition file (MODEL) and load the checkpoint (CKPT),
         mapping all tensors to CPU storage.
      2. Choose the output name ``<OUTPUT or dirname(CKPT)>/<CKPT root>.onnx``; if it
         already exists, append a ``-YYYYMMDD-HHMMSS`` suffix instead of overwriting.
      3. Rebuild ``MyHyperModel`` from the checkpoint's hyper-parameters (or the model
         defaults when none are stored), applying the batch-size/timesteps overrides.
      4. Load the weights, switch to eval mode and call ``model.to_onnx`` with a random
         input sample of shape ``[batch_size, *model.input_shape]``.

    Batch size is currently forced to 1. Exits with status 1 if the model file cannot
    be imported, the checkpoint cannot be loaded, or the model cannot be built.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    model_path = args['MODEL']
    ckpt_path = args['CKPT']
    output_dir = args['--output']

    from datetime import datetime
    from os import makedirs
    from os.path import basename
    from os.path import dirname
    from os.path import exists
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    from torch import load
    from torch import randn

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.utils import import_keras_model

    # Import the model definition file first so a bad MODEL fails fast
    model_base = basename(model_path)
    logger.info(f'Importing model from {model_base}')
    try:
        litemodule = import_keras_model(model_path)  # note works for pytorch lightning as well as keras
    except Exception as e:
        logger.exception(f'Error: could not import model from {model_path}: {e}')
        raise SystemExit(1)

    # Load the checkpoint next to get hparams if available; every tensor is mapped to CPU.
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    ckpt_base = basename(ckpt_path)
    ckpt_root = splitext(ckpt_base)[0]
    logger.info(f'Loading checkpoint from {ckpt_base}')
    try:
        checkpoint = load(ckpt_path, map_location=lambda storage, loc: storage)
    except Exception as e:
        logger.exception(f'Error: could not load checkpoint from {ckpt_path}: {e}')
        raise SystemExit(1)

    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size != 1:
            batch_size = 1
            logger.info('For now prediction only supports batch_size = 1, forcing it to 1 now')

    if timesteps is not None:
        timesteps = int(timesteps)

    if output_dir is None:
        output_dir = dirname(ckpt_path)
    elif not isdir(output_dir):
        makedirs(output_dir, exist_ok=True)

    ofname = join(output_dir, ckpt_root + '.onnx')
    if exists(ofname):
        # Do not overwrite an earlier export. Include the time as well as the date so
        # repeated exports on the same day get distinct names.
        ts = datetime.now().strftime('%Y%m%d-%H%M%S')
        ofname = join(output_dir, f'{ckpt_root}-{ts}.onnx')
    ofname_root = splitext(ofname)[0]

    # Setup logging file
    create_file_handler(ofname_root + '-onnx.log')
    update_console_handler(verbose)
    initial_log_messages('torchl_onnx')
    logger.info(f'Imported model from {model_base}')
    logger.info(f'Loaded checkpoint from {ckpt_base}')

    if 'hyper_parameters' in checkpoint:
        hparams = checkpoint['hyper_parameters']
        # hparams_name is optional in a checkpoint; do not fail just because it is absent
        logger.info(f'Found hyper-params on checkpoint named {checkpoint.get("hparams_name")} '
                    f'with {len(hparams)} total hparams.')
        # batch_size, when given, has already been forced to 1 above
        if batch_size is not None and hparams['batch_size'] != batch_size:
            logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
            hparams['batch_size'] = batch_size

        if timesteps is not None:
            if hparams['timesteps'] == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if hparams['timesteps'] != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override of 0, using model default.')
                timesteps = hparams['timesteps']

            if hparams['timesteps'] != timesteps:
                logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
                hparams['timesteps'] = timesteps

        logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
        try:
            model = litemodule.MyHyperModel(**hparams)  # use hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1)
    else:
        logger.warning('Warning: found checkpoint with no hyper-parameters, building model with defaults')
        try:
            tmp = litemodule.MyHyperModel()  # default hparams, only used to read the defaults
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1)

        if batch_size is not None:
            if tmp.batch_size != batch_size:
                logger.info(f'Overriding batch_size: default = {tmp.batch_size}; specified = {batch_size}.')
        else:
            batch_size = tmp.batch_size  # inherit

        if timesteps is not None:
            if tmp.timesteps == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if tmp.timesteps != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override.')
                timesteps = tmp.timesteps

            if tmp.timesteps != timesteps:
                logger.info(f'Overriding timesteps: default = {tmp.timesteps}; specified = {timesteps}.')
        else:
            timesteps = tmp.timesteps  # inherit

        logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
        # Same error handling as the other model builds
        try:
            model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1)

    logger.info('')
    logger.info('')
    logger.info(f'feature       {model.hparams.feature}')
    logger.info(f'num_classes   {model.num_classes}')
    logger.info(f'batch_size    {model.hparams.batch_size}')
    logger.info(f'timesteps     {model.hparams.timesteps}')
    logger.info(f'flatten       {model.flatten}')
    logger.info(f'add1ch        {model.add1ch}')
    logger.info(f'truth_mutex   {model.truth_mutex}')
    logger.info(f'input_shape   {model.input_shape}')
    logger.info('')
    logger.info(f'Loading weights from {ckpt_base}')
    # model.load_from_checkpoint(ckpt_path) has problems here; load the state dict directly
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # Prepend the batch dimension to a new list so the model's own input_shape is not
    # mutated (and a tuple input_shape also works).
    input_sample = randn([batch_size, *model.input_shape])
    logger.info('Creating onnx model ...')
    for m in model.modules():
        if 'instancenorm' in m.__class__.__name__.lower():
            # model.eval() above already does this recursively; kept for the explicit log
            logger.info(f'Forcing train=false for instancenorm instance {m}, {m.__class__.__name__.lower()}')
            m.train(False)
    model.to_onnx(file_path=ofname, input_sample=input_sample, export_params=True)
209
-
210
-
211
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # SystemExit instead of the site-provided exit(), which is intended for the
        # interactive interpreter and is unavailable when Python runs with -S.
        raise SystemExit()