sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. sonusai/__init__.py +36 -4
  2. sonusai/audiofe.py +111 -106
  3. sonusai/calc_metric_spenh.py +38 -22
  4. sonusai/genft.py +15 -6
  5. sonusai/genmix.py +14 -6
  6. sonusai/genmixdb.py +15 -7
  7. sonusai/gentcst.py +13 -6
  8. sonusai/lsdb.py +15 -5
  9. sonusai/main.py +58 -61
  10. sonusai/mixture/__init__.py +1 -0
  11. sonusai/mixture/config.py +1 -2
  12. sonusai/mkmanifest.py +43 -8
  13. sonusai/mkwav.py +15 -6
  14. sonusai/onnx_predict.py +16 -6
  15. sonusai/plot.py +16 -6
  16. sonusai/post_spenh_targetf.py +13 -6
  17. sonusai/summarize_metric_spenh.py +71 -0
  18. sonusai/tplot.py +14 -6
  19. sonusai/utils/__init__.py +4 -7
  20. sonusai/utils/asl_p56.py +3 -3
  21. sonusai/utils/asr.py +35 -8
  22. sonusai/utils/asr_functions/__init__.py +0 -5
  23. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  24. sonusai/utils/asr_manifest_functions/__init__.py +1 -0
  25. sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
  26. sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
  27. sonusai/utils/model_utils.py +30 -0
  28. sonusai/utils/onnx_utils.py +19 -45
  29. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
  30. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
  31. sonusai/data_generator/__init__.py +0 -5
  32. sonusai/data_generator/dataset_from_mixdb.py +0 -143
  33. sonusai/data_generator/keras_from_mixdb.py +0 -169
  34. sonusai/data_generator/torch_from_mixdb.py +0 -122
  35. sonusai/keras_onnx.py +0 -86
  36. sonusai/keras_predict.py +0 -231
  37. sonusai/keras_train.py +0 -334
  38. sonusai/torchl_onnx.py +0 -216
  39. sonusai/torchl_predict.py +0 -542
  40. sonusai/torchl_train.py +0 -223
  41. sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
  42. sonusai/utils/asr_functions/data.py +0 -16
  43. sonusai/utils/asr_functions/deepgram.py +0 -97
  44. sonusai/utils/asr_functions/fastwhisper.py +0 -90
  45. sonusai/utils/asr_functions/google.py +0 -95
  46. sonusai/utils/asr_functions/whisper.py +0 -49
  47. sonusai/utils/keras_utils.py +0 -226
  48. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
  49. {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
@@ -1,169 +0,0 @@
1
- import warnings
2
- from dataclasses import dataclass
3
- from multiprocessing import get_context
4
- from os import cpu_count
5
- from typing import Optional
6
-
7
- import numpy as np
8
-
9
- from sonusai.mixture import Feature
10
- from sonusai.mixture import GeneralizedIDs
11
- from sonusai.mixture import MixtureDatabase
12
- from sonusai.mixture import Truth
13
-
14
- with warnings.catch_warnings():
15
- warnings.simplefilter('ignore')
16
- from keras.utils import Sequence
17
-
18
-
19
@dataclass
class MPGlobal:
    """Holder for state shared between the pool initializer and pool kernel."""
    # Mixture database handle; None until _pool_initializer assigns it.
    mixdb: Optional[MixtureDatabase] = None


# Module-level instance written by _pool_initializer and read by _pool_kernel.
MP_GLOBAL = MPGlobal()
25
-
26
-
27
def _pool_initializer(location: str) -> None:
    """Open the mixture database at ``location`` and store it in ``MP_GLOBAL``.

    Used as the ``initializer`` of the multiprocessing pool, so it is called
    in each worker process when that worker starts.
    """
    database = MixtureDatabase(location)
    MP_GLOBAL.mixdb = database
29
-
30
-
31
def _pool_kernel(mixid: int) -> tuple[Feature, Truth]:
    """Return the ``(feature, truth)`` pair of mixture ``mixid``.

    Reads from the database that ``_pool_initializer`` stored in ``MP_GLOBAL``.
    """
    database = MP_GLOBAL.mixdb
    return database.mixture_ft(mixid)
33
-
34
-
35
class KerasFromMixtureDatabase(Sequence):
    """Keras ``Sequence`` that serves batches from a SonusAI mixture database.

    The feature/truth data of the mixtures that make up a batch are fetched
    through a multiprocessing pool, stacked along the first (frame) axis,
    trimmed or zero padded to the batch's frame window and finally reshaped
    with ``sonusai.utils.reshape_inputs``.
    """

    @dataclass(frozen=True)
    class BatchParams:
        """Recipe for assembling one batch."""
        mixids: list[int]  # mixtures whose frames contribute to this batch
        offset: int  # leading frames to drop from the stacked data
        extra: int  # trailing frames to drop from the stacked data
        padding: int  # zero frames to append (non-zero only for the final batch)

    def __init__(self,
                 mixdb: MixtureDatabase,
                 mixids: GeneralizedIDs,
                 batch_size: int,
                 timesteps: int,
                 flatten: bool,
                 add1ch: bool,
                 shuffle: bool = False):
        """Store the configuration, plan the batches and start the worker pool.

        :param mixdb: Mixture database to read from
        :param mixids: Mixture IDs to use (converted with ``mixdb.mixids_to_list``)
        :param batch_size: Batch size forwarded to ``reshape_inputs``
        :param timesteps: Timesteps forwarded to ``reshape_inputs``
        :param flatten: Forwarded to ``reshape_inputs``
        :param add1ch: Forwarded to ``reshape_inputs``
        :param shuffle: Shuffle the mixture order at the end of every epoch
        """
        self.mixdb = mixdb
        self.mixids = self.mixdb.mixids_to_list(mixids)
        self.batch_size = batch_size
        self.timesteps = timesteps
        self.flatten = flatten
        self.add1ch = add1ch
        self.shuffle = shuffle
        self.stride = self.mixdb.fg_stride
        self.feature_parameters = self.mixdb.feature_parameters
        self.num_classes = self.mixdb.num_classes
        self.mixture_frame_segments: Optional[int] = None
        self.batch_frame_segments: Optional[int] = None
        self.total_batches: Optional[int] = None

        self._initialize_mixtures()

        # Each worker opens its own MixtureDatabase via _pool_initializer.
        # NOTE(review): the 'fork' start method is POSIX-only, and nothing in
        # this class closes or terminates the pool.
        self.pool = get_context('fork').Pool(processes=cpu_count(),
                                             initializer=_pool_initializer,
                                             initargs=(self.mixdb.location,))

    def __len__(self) -> int:
        """Return the number of batches per epoch."""
        return self.total_batches

    def __getitem__(self, batch_index: int) -> tuple[np.ndarray, np.ndarray]:
        """Assemble and return the ``(feature, truth)`` data of one batch."""
        from sonusai.utils import reshape_inputs

        params = self.batch_params[batch_index]

        results = self.pool.map(_pool_kernel, params.mixids)
        feature = np.vstack([ft for ft, _ in results])
        truth = np.vstack([tr for _, tr in results])

        # Only touch the data when padding is actually required; stacking an
        # empty float64 block on every batch would copy (and upcast) it.
        if params.padding:
            feature = self._zero_pad(feature, params.padding)
            truth = self._zero_pad(truth, params.padding)

        # A stop of None keeps everything up to the end when there is no extra.
        stop = -params.extra if params.extra > 0 else None
        feature = feature[params.offset:stop]
        truth = truth[params.offset:stop]

        feature, truth = reshape_inputs(feature=feature,
                                        truth=truth,
                                        batch_size=self.batch_size,
                                        timesteps=self.timesteps,
                                        flatten=self.flatten,
                                        add1ch=self.add1ch)

        return feature, truth

    def on_epoch_end(self) -> None:
        """Reshuffle the mixtures and re-plan the batches when shuffling is enabled."""
        import random

        if self.shuffle:
            random.shuffle(self.mixids)
            self._initialize_mixtures()

    @staticmethod
    def _zero_pad(data: np.ndarray, frames: int) -> np.ndarray:
        """Append ``frames`` all-zero frames to ``data``, keeping its dtype."""
        pad_shape = (frames,) + data.shape[1:]
        return np.vstack([data, np.zeros(pad_shape, dtype=data.dtype)])

    def _initialize_mixtures(self) -> None:
        """Compute ``total_batches``, ``batch_params`` and ``file_indices``."""
        from sonusai.utils import get_frames_per_batch

        frames_per_batch = get_frames_per_batch(self.batch_size, self.timesteps)
        # Always extend the number of batches to use all available data
        # The last batch may need padding
        self.total_batches = int(np.ceil(self.mixdb.total_feature_frames(self.mixids) / frames_per_batch))

        # Compute mixid, offset, and extra for dataset
        # offsets and extras are needed because mixtures are not guaranteed to fall on batch boundaries.
        # When fetching a new index that starts in the middle of a sequence of mixtures, the
        # previous feature frame offset must be maintained in order to preserve the correct
        # data sequence. And the extra must be maintained in order to preserve the correct data length.
        cumulative_frames = 0
        start_mixture_index = 0
        offset = 0
        self.batch_params = []
        self.file_indices = []
        total_frames = 0
        for idx, mixid in enumerate(self.mixids):
            current_frames = self.mixdb.mixture(mixid).samples // self.mixdb.feature_step_samples
            self.file_indices.append(slice(total_frames, total_frames + current_frames))
            total_frames += current_frames
            cumulative_frames += current_frames
            # A long mixture can complete more than one batch.
            while cumulative_frames >= frames_per_batch:
                extra = cumulative_frames - frames_per_batch
                mixids = self.mixids[start_mixture_index:idx + 1]
                self.batch_params.append(self.BatchParams(mixids=mixids, offset=offset, extra=extra, padding=0))
                if extra == 0:
                    # Batch ended exactly on a mixture boundary
                    start_mixture_index = idx + 1
                    offset = 0
                else:
                    # Next batch starts inside the current mixture
                    start_mixture_index = idx
                    offset = current_frames - extra
                cumulative_frames = extra

        # If needed, add final batch with padding
        needed_frames = self.total_batches * frames_per_batch
        padding = needed_frames - total_frames
        if padding != 0:
            mixids = self.mixids[start_mixture_index:]
            self.batch_params.append(self.BatchParams(mixids=mixids, offset=offset, extra=0, padding=padding))


# Alternative name for the same class.
KerasFromH5 = KerasFromMixtureDatabase
@@ -1,122 +0,0 @@
1
- from typing import Optional
2
-
3
- import numpy as np
4
- from torch.utils.data import DataLoader
5
- from torch.utils.data import Dataset
6
- from torch.utils.data import Sampler
7
-
8
- from sonusai.mixture import GeneralizedIDs
9
- from sonusai.mixture import MixtureDatabase
10
-
11
-
12
class MixtureDatabaseDataset(Dataset):
    """PyTorch ``Dataset`` with one item per mixture of a SonusAI mixture database."""

    def __init__(self,
                 mixdb: MixtureDatabase,
                 mixids: GeneralizedIDs,
                 cut_len: int,
                 flatten: bool,
                 add1ch: bool,
                 random_cut: bool = True):
        """Store the database, the resolved mixture IDs and the cut settings.

        :param mixdb: Mixture database to read from
        :param mixids: Mixture IDs to use (converted with ``mixdb.mixids_to_list``)
        :param cut_len: Number of frames per item; values <= 0 disable cutting
        :param flatten: Forwarded to ``reshape_inputs``
        :param add1ch: Forwarded to ``reshape_inputs``
        :param random_cut: Pick a random start when cutting a longer mixture
        """
        self.mixdb = mixdb
        self.mixids = self.mixdb.mixids_to_list(mixids)
        self.cut_len = cut_len
        self.flatten = flatten
        self.add1ch = add1ch
        self.random_cut = random_cut

    def __len__(self):
        """Return the number of mixtures in the dataset."""
        return len(self.mixids)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray, int]:
        """Return ``(feature, truth, idx)`` for the mixture at position ``idx``.

        When ``cut_len`` is positive the data is brought to exactly ``cut_len``
        frames: shorter mixtures are repeated (the last repetition truncated),
        longer ones are sliced from a random or zero start.
        """
        import random

        from sonusai.utils import reshape_inputs

        mixid = self.mixids[idx]
        feature, truth = self.mixdb.mixture_ft(mixid)
        feature, truth = reshape_inputs(feature=feature,
                                        truth=truth,
                                        batch_size=1,
                                        timesteps=0,
                                        flatten=self.flatten,
                                        add1ch=self.add1ch)

        length = feature.shape[0]
        target = self.cut_len

        if target <= 0:
            return feature, truth, idx

        if length < target:
            # Tile whole copies, then top up with the leading frames.
            repeats, remainder = divmod(target, length)
            feature = np.vstack([feature] * repeats + [feature[:remainder]])
            truth = np.vstack([truth] * repeats + [truth[:remainder]])
        else:
            start = random.randint(0, length - target) if self.random_cut else 0
            stop = start + target
            feature = feature[start:stop]
            truth = truth[start:stop]

        return feature, truth, idx
72
-
73
-
74
class AawareDataLoader(DataLoader):
    """``DataLoader`` that additionally carries a ``cut_len`` value."""
    # Backing storage for the cut_len property; None until it is set.
    _cut_len: Optional[int] = None

    @property
    def cut_len(self) -> Optional[int]:
        """Cut length stored on this loader (``None`` if never set)."""
        return self._cut_len

    @cut_len.setter
    def cut_len(self, value: int) -> None:
        self._cut_len = value
84
-
85
-
86
def TorchFromMixtureDatabase(mixdb: MixtureDatabase,
                             mixids: GeneralizedIDs,
                             batch_size: int,
                             flatten: bool,
                             add1ch: bool,
                             num_workers: int = 0,
                             cut_len: int = 0,
                             drop_last: bool = False,
                             shuffle: bool = False,
                             random_cut: bool = True,
                             sampler: Optional[type[Sampler]] = None,
                             pin_memory: bool = False) -> AawareDataLoader:
    """Build an ``AawareDataLoader`` over a SonusAI mixture database.

    A ``MixtureDatabaseDataset`` is created from the database arguments. If a
    ``sampler`` class is given it is instantiated with that dataset. The
    returned loader has its ``cut_len`` attribute set to ``cut_len``.
    """
    dataset = MixtureDatabaseDataset(mixdb=mixdb,
                                     mixids=mixids,
                                     cut_len=cut_len,
                                     flatten=flatten,
                                     add1ch=add1ch,
                                     random_cut=random_cut)

    # The sampler is passed as a class and bound to the dataset here.
    my_sampler = None if sampler is None else sampler(dataset)

    loader = AawareDataLoader(dataset=dataset,
                              batch_size=batch_size,
                              pin_memory=pin_memory,
                              shuffle=shuffle,
                              sampler=my_sampler,
                              drop_last=drop_last,
                              num_workers=num_workers)
    loader.cut_len = cut_len
    return loader
sonusai/keras_onnx.py DELETED
@@ -1,86 +0,0 @@
1
- """sonusai keras_onnx
2
-
3
- usage: keras_onnx [-hvr] (-m MODEL) (-w WEIGHTS) [-b BATCH] [-t TSTEPS] [-o OUTPUT]
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -m MODEL, --model MODEL Python model file.
9
- -w WEIGHTS, --weights WEIGHTS Keras model weights file.
10
- -b BATCH, --batch BATCH Batch size.
11
- -t TSTEPS, --tsteps TSTEPS Timesteps.
12
- -o OUTPUT, --output OUTPUT Output directory.
13
-
14
- Convert a trained Keras model to ONNX.
15
-
16
- Inputs:
17
- MODEL A SonusAI Python model file with build and/or hypermodel functions.
18
- WEIGHTS A Keras model weights file (or model file with weights).
19
-
20
- Outputs:
21
- OUTPUT/ A directory containing:
22
- <MODEL>.onnx Model file with batch_size and timesteps equal to provided parameters
23
- <MODEL>-b1.onnx Model file with batch_size=1 and if the timesteps dimension exists it
24
- is set to 1 (useful for real-time inference applications)
25
- keras_onnx.log
26
-
27
- Results are written into subdirectory <MODEL>-<TIMESTAMP> unless OUTPUT is specified.
28
-
29
- """
30
- from sonusai import logger
31
-
32
-
33
def main() -> None:
    """Run ``keras_onnx`` from the command line.

    Parses the docopt usage in the module docstring, creates the output
    directory (a timestamped name derived from the model file when OUTPUT is
    not given), sets up logging and calls ``sonusai.utils.keras_onnx``.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weight_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    output_dir = args['--output']

    from os import makedirs
    from os.path import basename
    from os.path import join
    from os.path import splitext

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.utils import create_ts_name
    from sonusai.utils import keras_onnx

    # Model file name without directory and extension
    model_root, _ = splitext(basename(model_name))

    # docopt delivers option values as strings (or None when absent)
    batch_size = None if batch_size is None else int(batch_size)
    timesteps = None if timesteps is None else int(timesteps)

    if output_dir is None:
        output_dir = create_ts_name(model_root)

    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    log_name = join(output_dir, 'keras_onnx.log')
    create_file_handler(log_name)
    update_console_handler(verbose)
    initial_log_messages('keras_onnx')

    keras_onnx(model_name, weight_name, timesteps, batch_size, output_dir)
79
-
80
-
81
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is an interactive helper installed by the site module and is
        # not guaranteed to exist; raise SystemExit directly instead.
        raise SystemExit(0)
sonusai/keras_predict.py DELETED
@@ -1,231 +0,0 @@
1
- """sonusai keras_predict
2
-
3
- usage: keras_predict [-hvr] [-i MIXID] (-m MODEL) (-w KMODEL) [-b BATCH] [-t TSTEPS] INPUT ...
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
9
- -m MODEL, --model MODEL Python model file.
10
- -w KMODEL, --weights KMODEL Keras model weights file.
11
- -b BATCH, --batch BATCH Batch size.
12
- -t TSTEPS, --tsteps TSTEPS Timesteps.
13
- -r, --reset Reset model between each file.
14
-
15
- Run prediction on a trained Keras model defined by a SonusAI Keras Python model file using SonusAI genft or WAV data.
16
-
17
- Inputs:
18
- MODEL A SonusAI Python model file with build and/or hypermodel functions.
19
- KMODEL A Keras model weights file (or model file with weights).
20
- INPUT The input data must be one of the following:
21
- * Single WAV file or glob of WAV files
22
- Using the given model, generate feature data and run prediction. A model file must be
23
- provided. The MIXID is ignored.
24
-
25
- * directory
26
- Using the given SonusAI mixture database directory, generate feature and truth data if not found.
27
- Run prediction. The MIXID is required.
28
-
29
- Outputs the following to kpredict-<TIMESTAMP> directory:
30
- <id>.h5
31
- dataset: predict
32
- keras_predict.log
33
-
34
- """
35
- from typing import Any
36
-
37
- from sonusai import logger
38
- from sonusai.mixture import Feature
39
- from sonusai.mixture import Predict
40
-
41
-
42
def main() -> None:
    """Run ``keras_predict`` from the command line.

    Parses the docopt usage in the module docstring, builds the Keras model
    described by MODEL, loads the KMODEL weights and writes predictions to a
    ``kpredict`` timestamped output directory. INPUT is either a single
    mixture database directory or a list of WAV files.

    Raises:
        SystemExit: if the mixture database feature does not match the model
            feature, or if INPUT is neither a single directory nor a list of
            existing ``.wav`` files.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    model_name = args['--model']
    weights_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    reset = args['--reset']
    input_name = args['INPUT']

    from os import makedirs
    from os.path import basename
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import splitext

    import h5py
    import keras_tuner as kt
    import tensorflow as tf
    from keras import backend as kb

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import import_and_check_keras_model
    from sonusai.utils import reshape_outputs

    # docopt delivers option values as strings (or None when absent)
    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name('kpredict')
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_predict')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    hypermodel = import_and_check_keras_model(model_name=model_name,
                                              weights_name=weights_name,
                                              timesteps=timesteps,
                                              batch_size=batch_size)
    built_model = hypermodel.build_model(kt.HyperParameters())

    frames_per_batch = get_frames_per_batch(hypermodel.batch_size, hypermodel.timesteps)

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info('')

    logger.info(f'Loading weights from {weights_name}')
    built_model.load_weights(weights_name)

    logger.info('')
    if len(input_name) == 1 and isdir(input_name[0]):
        input_name = input_name[0]
        logger.info(f'Load mixture database from {input_name}')
        mixdb = MixtureDatabase(input_name)

        if mixdb.feature != hypermodel.feature:
            # No exception is being handled here, so use error() rather than
            # exception(), which would append a meaningless traceback.
            logger.error('Feature in mixture database does not match feature in model')
            raise SystemExit(1)

        mixids = mixdb.mixids_to_list(mixids)
        if reset:
            # reset mode cycles through each file one at a time
            for mixid in mixids:
                feature, _ = mixdb.mixture_ft(mixid)

                feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                                    built_model=built_model,
                                                    feature=feature,
                                                    frames_per_batch=frames_per_batch)

                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    # Replace any previous result
                    if 'predict' in f:
                        del f['predict']
                    f.create_dataset(name='predict', data=predict)
        else:
            # Run all data at once using a data generator
            generator = KerasFromH5(mixdb=mixdb,
                                    mixids=mixids,
                                    batch_size=hypermodel.batch_size,
                                    timesteps=hypermodel.timesteps,
                                    flatten=hypermodel.flatten,
                                    add1ch=hypermodel.add1ch)

            predict = built_model.predict(generator, batch_size=hypermodel.batch_size, verbose=1)
            predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)

            # Write data to separate files
            for idx, mixid in enumerate(mixids):
                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    if 'predict' in f:
                        del f['predict']
                    # file_indices[idx] is the frame slice of this mixture
                    f.create_dataset('predict', data=predict[generator.file_indices[idx]])

        logger.info(f'Saved results to {output_dir}')
        return

    if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
        logger.error(f'Do not know how to process input from {input_name}')
        raise SystemExit(1)

    logger.info(f'Run prediction on {len(input_name):,} WAV files')
    for file in input_name:
        # Convert WAV to feature data
        audio = read_audio(file)
        feature = get_feature_from_audio(audio=audio, feature_mode=hypermodel.feature)

        feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                            built_model=built_model,
                                            feature=feature,
                                            frames_per_batch=frames_per_batch)

        output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
        with h5py.File(output_name, 'a') as f:
            if 'feature' in f:
                del f['feature']
            f.create_dataset(name='feature', data=feature)

            if 'predict' in f:
                del f['predict']
            f.create_dataset(name='predict', data=predict)

    logger.info(f'Saved results to {output_dir}')
201
-
202
-
203
def _pad_and_predict(hypermodel: Any,
                     built_model: Any,
                     feature: Feature,
                     frames_per_batch: int) -> tuple[Feature, Predict]:
    """Zero pad ``feature`` to a whole number of batches and run prediction.

    :param hypermodel: Object providing ``batch_size``, ``timesteps``, ``flatten`` and ``add1ch``
    :param built_model: Model whose ``predict`` method is called
    :param feature: Feature data; padded along axis 0 (the pad width assumes three dimensions)
    :param frames_per_batch: Number of frames in one batch
    :return: Tuple of (padded and reshaped feature, prediction trimmed to the original frame count)
    """
    import numpy as np

    from sonusai.utils import reshape_inputs
    from sonusai.utils import reshape_outputs

    frames = feature.shape[0]
    if frames == 0:
        # Keep one full batch of zeros so the model still receives data.
        padding = frames_per_batch
    else:
        # Pad up to the next batch boundary only; an exact multiple of
        # frames_per_batch needs no padding at all.
        padding = -frames % frames_per_batch
    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
    feature, _ = reshape_inputs(feature=feature,
                                batch_size=hypermodel.batch_size,
                                timesteps=hypermodel.timesteps,
                                flatten=hypermodel.flatten,
                                add1ch=hypermodel.add1ch)
    predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
    predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
    # Drop the predictions that correspond to the padding frames
    predict = predict[:frames, :]
    return feature, predict
224
-
225
-
226
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is an interactive helper installed by the site module and is
        # not guaranteed to exist; raise SystemExit directly instead.
        raise SystemExit(0)