sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +157 -61
- sonusai/calc_metric_spenh-save.py +1334 -0
- sonusai/calc_metric_spenh.py +15 -8
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +14 -6
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/mkmanifest.py +14 -6
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict-old.py +240 -0
- sonusai/onnx_predict-save.py +487 -0
- sonusai/onnx_predict.py +446 -182
- sonusai/ovino_predict.py +508 -0
- sonusai/ovino_query_devices.py +47 -0
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/torchl_onnx-old.py +216 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/onnx_utils.py +128 -39
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/METADATA +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/RECORD +26 -19
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/WHEEL +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/entry_points.txt +0 -0
sonusai/calc_metric_spenh.py
CHANGED
@@ -60,6 +60,7 @@ Metric and extraction data are written into prediction location PLOC as separate
|
|
60
60
|
Inputs:
|
61
61
|
|
62
62
|
"""
|
63
|
+
import signal
|
63
64
|
from dataclasses import dataclass
|
64
65
|
from typing import Optional
|
65
66
|
|
@@ -67,14 +68,24 @@ import matplotlib
|
|
67
68
|
import matplotlib.pyplot as plt
|
68
69
|
import numpy as np
|
69
70
|
import pandas as pd
|
70
|
-
|
71
|
-
from sonusai import logger
|
72
71
|
from sonusai.mixture import AudioF
|
73
72
|
from sonusai.mixture import AudioT
|
74
73
|
from sonusai.mixture import Feature
|
75
74
|
from sonusai.mixture import MixtureDatabase
|
76
75
|
from sonusai.mixture import Predict
|
77
76
|
|
77
|
+
|
78
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
85
|
+
|
86
|
+
|
87
|
+
signal.signal(signal.SIGINT, signal_handler)
|
88
|
+
|
78
89
|
matplotlib.use('SVG')
|
79
90
|
|
80
91
|
|
@@ -1145,7 +1156,7 @@ def main():
|
|
1145
1156
|
fnb = 'metric_spenh_whspaaw_' + whisper_model + '_'
|
1146
1157
|
logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
|
1147
1158
|
enable_asr_warmup = True
|
1148
|
-
elif wer_method == '
|
1159
|
+
elif wer_method == 'faster_whisper':
|
1149
1160
|
fnb = 'metric_spenh_fwhsp_' + whisper_model + '_'
|
1150
1161
|
logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
|
1151
1162
|
enable_asr_warmup = True
|
@@ -1326,8 +1337,4 @@ def main():
|
|
1326
1337
|
|
1327
1338
|
|
1328
1339
|
if __name__ == '__main__':
|
1329
|
-
|
1330
|
-
main()
|
1331
|
-
except KeyboardInterrupt:
|
1332
|
-
logger.info('Canceled due to keyboard interrupt')
|
1333
|
-
exit()
|
1340
|
+
main()
|
sonusai/genft.py
CHANGED
@@ -23,14 +23,26 @@ Outputs the following to the mixture database directory:
|
|
23
23
|
genft.log
|
24
24
|
|
25
25
|
"""
|
26
|
+
import signal
|
26
27
|
from dataclasses import dataclass
|
27
28
|
|
28
|
-
from sonusai import logger
|
29
29
|
from sonusai.mixture import GenFTData
|
30
30
|
from sonusai.mixture import GeneralizedIDs
|
31
31
|
from sonusai.mixture import MixtureDatabase
|
32
32
|
|
33
33
|
|
34
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
41
|
+
|
42
|
+
|
43
|
+
signal.signal(signal.SIGINT, signal_handler)
|
44
|
+
|
45
|
+
|
34
46
|
@dataclass
|
35
47
|
class MPGlobal:
|
36
48
|
mixdb: MixtureDatabase = None
|
@@ -123,6 +135,7 @@ def main() -> None:
|
|
123
135
|
|
124
136
|
from sonusai import create_file_handler
|
125
137
|
from sonusai import initial_log_messages
|
138
|
+
from sonusai import logger
|
126
139
|
from sonusai import update_console_handler
|
127
140
|
from sonusai.mixture import check_audio_files_exist
|
128
141
|
from sonusai.utils import human_readable_size
|
@@ -177,8 +190,4 @@ def main() -> None:
|
|
177
190
|
|
178
191
|
|
179
192
|
if __name__ == '__main__':
|
180
|
-
|
181
|
-
main()
|
182
|
-
except KeyboardInterrupt:
|
183
|
-
logger.info('Canceled due to keyboard interrupt')
|
184
|
-
raise SystemExit(0)
|
193
|
+
main()
|
sonusai/genmix.py
CHANGED
@@ -27,14 +27,26 @@ Outputs the following to the mixture database directory:
|
|
27
27
|
<id>.txt
|
28
28
|
genmix.log
|
29
29
|
"""
|
30
|
+
import signal
|
30
31
|
from dataclasses import dataclass
|
31
32
|
|
32
|
-
from sonusai import logger
|
33
33
|
from sonusai.mixture import GenMixData
|
34
34
|
from sonusai.mixture import GeneralizedIDs
|
35
35
|
from sonusai.mixture import MixtureDatabase
|
36
36
|
|
37
37
|
|
38
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
45
|
+
|
46
|
+
|
47
|
+
signal.signal(signal.SIGINT, signal_handler)
|
48
|
+
|
49
|
+
|
38
50
|
@dataclass
|
39
51
|
class MPGlobal:
|
40
52
|
mixdb: MixtureDatabase = None
|
@@ -210,8 +222,4 @@ def main() -> None:
|
|
210
222
|
|
211
223
|
|
212
224
|
if __name__ == '__main__':
|
213
|
-
|
214
|
-
main()
|
215
|
-
except KeyboardInterrupt:
|
216
|
-
logger.info('Canceled due to keyboard interrupt')
|
217
|
-
raise SystemExit(0)
|
225
|
+
main()
|
sonusai/genmixdb.py
CHANGED
@@ -112,13 +112,25 @@ targets:
|
|
112
112
|
will find all .wav files in the specified directories and process them as targets.
|
113
113
|
|
114
114
|
"""
|
115
|
+
import signal
|
115
116
|
from dataclasses import dataclass
|
116
117
|
|
117
|
-
from sonusai import logger
|
118
118
|
from sonusai.mixture import Mixture
|
119
119
|
from sonusai.mixture import MixtureDatabase
|
120
120
|
|
121
121
|
|
122
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
129
|
+
|
130
|
+
|
131
|
+
signal.signal(signal.SIGINT, signal_handler)
|
132
|
+
|
133
|
+
|
122
134
|
@dataclass
|
123
135
|
class MPGlobal:
|
124
136
|
mixdb: MixtureDatabase = None
|
@@ -509,8 +521,4 @@ def main() -> None:
|
|
509
521
|
|
510
522
|
|
511
523
|
if __name__ == '__main__':
|
512
|
-
|
513
|
-
main()
|
514
|
-
except KeyboardInterrupt:
|
515
|
-
logger.info('Canceled due to keyboard interrupt')
|
516
|
-
raise SystemExit(0)
|
524
|
+
main()
|
sonusai/gentcst.py
CHANGED
@@ -44,10 +44,21 @@ Outputs:
|
|
44
44
|
gentcst.log
|
45
45
|
|
46
46
|
"""
|
47
|
+
import signal
|
47
48
|
from dataclasses import dataclass
|
48
49
|
from typing import Optional
|
49
50
|
|
50
|
-
|
51
|
+
|
52
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
59
|
+
|
60
|
+
|
61
|
+
signal.signal(signal.SIGINT, signal_handler)
|
51
62
|
|
52
63
|
CONFIG_FILE = 'config.yml'
|
53
64
|
|
@@ -621,8 +632,4 @@ def main() -> None:
|
|
621
632
|
|
622
633
|
|
623
634
|
if __name__ == '__main__':
|
624
|
-
|
625
|
-
main()
|
626
|
-
except KeyboardInterrupt:
|
627
|
-
logger.info('Canceled due to keyboard interrupt')
|
628
|
-
raise SystemExit(0)
|
635
|
+
main()
|
sonusai/lsdb.py
CHANGED
@@ -15,11 +15,25 @@ Inputs:
|
|
15
15
|
LOC A SonusAI mixture database directory.
|
16
16
|
|
17
17
|
"""
|
18
|
+
import signal
|
19
|
+
|
18
20
|
from sonusai import logger
|
19
21
|
from sonusai.mixture import GeneralizedIDs
|
20
22
|
from sonusai.mixture import MixtureDatabase
|
21
23
|
|
22
24
|
|
25
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
32
|
+
|
33
|
+
|
34
|
+
signal.signal(signal.SIGINT, signal_handler)
|
35
|
+
|
36
|
+
|
23
37
|
def lsdb(mixdb: MixtureDatabase,
|
24
38
|
mixids: GeneralizedIDs = None,
|
25
39
|
truth_index: int = None,
|
@@ -142,8 +156,4 @@ def main() -> None:
|
|
142
156
|
|
143
157
|
|
144
158
|
if __name__ == '__main__':
|
145
|
-
|
146
|
-
main()
|
147
|
-
except KeyboardInterrupt:
|
148
|
-
logger.info('Canceled due to keyboard interrupt')
|
149
|
-
raise SystemExit(0)
|
159
|
+
main()
|
sonusai/mkmanifest.py
CHANGED
@@ -46,7 +46,19 @@ Example usage for LibriSpeech:
|
|
46
46
|
sonusai mkmanifest -mlibrispeech -eADAT -oasr_manifest.json --include='*.flac' train-clean-100
|
47
47
|
sonusai mkmanifest -m mcgill-speech -e ADAT -o asr_manifest_16k.json 16k-LP7/
|
48
48
|
"""
|
49
|
-
|
49
|
+
import signal
|
50
|
+
|
51
|
+
|
52
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
59
|
+
|
60
|
+
|
61
|
+
signal.signal(signal.SIGINT, signal_handler)
|
50
62
|
|
51
63
|
VALID_METHOD = ['librispeech', 'vctk_noisy_speech', 'mcgill-speech']
|
52
64
|
|
@@ -194,8 +206,4 @@ def main() -> None:
|
|
194
206
|
|
195
207
|
|
196
208
|
if __name__ == '__main__':
|
197
|
-
|
198
|
-
main()
|
199
|
-
except KeyboardInterrupt:
|
200
|
-
logger.info('Canceled due to keyboard interrupt')
|
201
|
-
raise SystemExit(0)
|
209
|
+
main()
|
sonusai/mkwav.py
CHANGED
@@ -23,13 +23,25 @@ Outputs the following to the mixture database directory:
|
|
23
23
|
mkwav.log
|
24
24
|
|
25
25
|
"""
|
26
|
+
import signal
|
26
27
|
from dataclasses import dataclass
|
27
28
|
|
28
|
-
from sonusai import logger
|
29
29
|
from sonusai.mixture import AudioT
|
30
30
|
from sonusai.mixture import MixtureDatabase
|
31
31
|
|
32
32
|
|
33
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are supplied by the
    ``signal`` module and are ignored.
    """
    # Imported lazily so the logger is only pulled in when an interrupt actually happens.
    from sonusai import logger

    logger.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)
|
40
|
+
|
41
|
+
|
42
|
+
signal.signal(signal.SIGINT, signal_handler)
|
43
|
+
|
44
|
+
|
33
45
|
@dataclass
|
34
46
|
class MPGlobal:
|
35
47
|
mixdb: MixtureDatabase = None
|
@@ -120,6 +132,7 @@ def main() -> None:
|
|
120
132
|
import sonusai
|
121
133
|
from sonusai import create_file_handler
|
122
134
|
from sonusai import initial_log_messages
|
135
|
+
from sonusai import logger
|
123
136
|
from sonusai import update_console_handler
|
124
137
|
from sonusai.mixture import check_audio_files_exist
|
125
138
|
from sonusai.utils import pp_tqdm_imap
|
@@ -164,8 +177,4 @@ def main() -> None:
|
|
164
177
|
|
165
178
|
|
166
179
|
if __name__ == '__main__':
|
167
|
-
|
168
|
-
main()
|
169
|
-
except KeyboardInterrupt:
|
170
|
-
logger.info('Canceled due to keyboard interrupt')
|
171
|
-
raise SystemExit(0)
|
180
|
+
main()
|
@@ -0,0 +1,240 @@
|
|
1
|
+
"""sonusai predict
|
2
|
+
|
3
|
+
usage: predict [-hvr] [-i MIXID] (-m MODEL) INPUT
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
|
9
|
+
-m MODEL, --model MODEL Trained ONNX model file.
|
10
|
+
-r, --reset Reset model between each file.
|
11
|
+
|
12
|
+
Run prediction on a trained ONNX model using SonusAI genft or WAV data.
|
13
|
+
|
14
|
+
Inputs:
|
15
|
+
MODEL A SonusAI trained ONNX model file.
|
16
|
+
INPUT The input data must be one of the following:
|
17
|
+
* WAV
|
18
|
+
Using the given model, generate feature data and run prediction. A model file must be
|
19
|
+
provided. The MIXID is ignored.
|
20
|
+
|
21
|
+
* directory
|
22
|
+
Using the given SonusAI mixture database directory, generate feature and truth data if not found.
|
23
|
+
Run prediction. The MIXID is required.
|
24
|
+
|
25
|
+
Outputs the following to opredict-<TIMESTAMP> directory:
|
26
|
+
<id>.h5
|
27
|
+
dataset: predict
|
28
|
+
onnx_predict.log
|
29
|
+
|
30
|
+
"""
|
31
|
+
|
32
|
+
from sonusai import logger
|
33
|
+
from sonusai.mixture import Feature
|
34
|
+
from sonusai.mixture import Predict
|
35
|
+
from sonusai.utils import SonusAIMetaData
|
36
|
+
|
37
|
+
|
38
|
+
def main() -> None:
    """Run the ``predict`` command-line tool on a WAV file or a mixture database.

    The module docstring is parsed with docopt. A timestamped ``opredict``
    directory is created to hold ``onnx_predict.log``, the ONNX model is loaded
    to read its SonusAI metadata, and prediction is then run on INPUT:

    * INPUT ending in ``.wav``: feature data is generated from the audio and the
      ``feature`` and ``predict`` datasets are written to an ``.h5`` file next
      to the WAV file (same base name).
    * INPUT is a directory: it is opened as a ``MixtureDatabase`` and a
      ``predict`` dataset is written for every requested mixture ID into the
      output directory. With ``--reset`` each mixture is predicted on its own;
      otherwise all features are stacked, predicted in one pass and split back
      per mixture.

    Raises:
        SystemExit: With status 1 when INPUT is neither a ``.wav`` file nor a
            directory, or when the database feature does not match the model
            feature.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    model_name = args['--model']
    reset = args['--reset']
    input_name = args['INPUT']

    from os import makedirs
    from os.path import isdir
    from os.path import join
    from os.path import splitext

    import h5py
    import onnxruntime as rt
    import numpy as np

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import get_sonusai_metadata

    def replace_dataset(h5file, name: str, data) -> None:
        # Files are opened in append mode, so drop any dataset left over from a
        # previous run before writing the new one.
        if name in h5file:
            del h5file[name]
        h5file.create_dataset(name=name, data=data)

    output_dir = create_ts_name('opredict')
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'onnx_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('onnx_predict')

    # This session is only used to read the model metadata; pad_and_predict()
    # opens its own session from model_name.
    model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
    model_metadata = get_sonusai_metadata(model)

    batch_size = model_metadata.input_shape[0]
    if model_metadata.timestep:
        timesteps = model_metadata.input_shape[1]
    else:
        timesteps = 0
    num_classes = model_metadata.output_shape[-1]

    frames_per_batch = get_frames_per_batch(batch_size, timesteps)

    logger.info('')
    logger.info(f'feature {model_metadata.feature}')
    logger.info(f'num_classes {num_classes}')
    logger.info(f'batch_size {batch_size}')
    logger.info(f'timesteps {timesteps}')
    logger.info(f'flatten {model_metadata.flattened}')
    logger.info(f'add1ch {model_metadata.channel}')
    logger.info(f'truth_mutex {model_metadata.mutex}')
    logger.info(f'input_shape {model_metadata.input_shape}')
    logger.info(f'output_shape {model_metadata.output_shape}')
    logger.info('')

    # Arguments shared by every pad_and_predict() call below.
    predict_args = dict(model_name=model_name,
                        model_metadata=model_metadata,
                        frames_per_batch=frames_per_batch,
                        batch_size=batch_size,
                        timesteps=timesteps,
                        reset=reset)

    if splitext(input_name)[1] == '.wav':
        # Convert WAV to feature data
        logger.info('')
        logger.info(f'Run prediction on {input_name}')
        audio = read_audio(input_name)
        feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)

        predict = pad_and_predict(feature=feature, **predict_args)

        output_name = splitext(input_name)[0] + '.h5'
        with h5py.File(output_name, 'a') as f:
            replace_dataset(f, 'feature', feature)
            replace_dataset(f, 'predict', predict)

        logger.info(f'Saved results to {output_name}')
        return

    if not isdir(input_name):
        # logger.error, not logger.exception: there is no active exception here,
        # so exception() would only append a meaningless "NoneType: None" traceback.
        logger.error(f'Do not know how to process input from {input_name}')
        raise SystemExit(1)

    mixdb = MixtureDatabase(input_name)

    if mixdb.feature != model_metadata.feature:
        logger.error('Feature in mixture database does not match feature in model')
        raise SystemExit(1)

    mixids = mixdb.mixids_to_list(mixids)
    if reset:
        # reset mode cycles through each file one at a time
        for mixid in mixids:
            feature, _ = mixdb.mixture_ft(mixid)

            predict = pad_and_predict(feature=feature, **predict_args)

            output_name = join(output_dir, mixdb.mixtures[mixid].name)
            with h5py.File(output_name, 'a') as f:
                replace_dataset(f, 'predict', predict)
    else:
        # Stack all mixtures into one feature array, remembering which frame
        # range belongs to which mixture so the prediction can be split again.
        features: list[Feature] = []
        file_indices: list[slice] = []
        total_frames = 0
        for mixid in mixids:
            current_feature, _ = mixdb.mixture_ft(mixid)
            current_frames = current_feature.shape[0]
            features.append(current_feature)
            file_indices.append(slice(total_frames, total_frames + current_frames))
            total_frames += current_frames
        feature = np.vstack(features)

        predict = pad_and_predict(feature=feature, **predict_args)

        # Write data to separate files
        for mixid, frame_range in zip(mixids, file_indices):
            output_name = join(output_dir, mixdb.mixtures[mixid].name)
            with h5py.File(output_name, 'a') as f:
                replace_dataset(f, 'predict', predict[frame_range])

    logger.info(f'Saved results to {output_dir}')
|
189
|
+
|
190
|
+
|
191
|
+
def pad_and_predict(feature: Feature,
                    model_name: str,
                    model_metadata: SonusAIMetaData,
                    frames_per_batch: int,
                    batch_size: int,
                    timesteps: int,
                    reset: bool) -> Predict:
    """Zero-pad feature data to whole batches, run the ONNX model and trim the result.

    Args:
        feature: Feature data; axis 0 is the frame axis and the array is padded
            as a 3-D array.
        model_name: Path of the ONNX model file to load.
        model_metadata: SonusAI metadata of the model; ``input_shape``,
            ``flattened`` and ``channel`` are used to shape the model input.
        frames_per_batch: Number of frames consumed by one model invocation.
        batch_size: Batch size passed to ``reshape_inputs``.
        timesteps: Timesteps passed to ``reshape_inputs`` / ``reshape_outputs``.
        reset: If true, a fresh inference session is created before every
            sequence after the first.

    Returns:
        Prediction data truncated to the original number of input frames.
    """
    import onnxruntime as rt
    import numpy as np

    from sonusai.utils import reshape_inputs
    from sonusai.utils import reshape_outputs

    frames = feature.shape[0]
    # Pad only up to the next multiple of frames_per_batch. The previous formula
    # (frames_per_batch - frames % frames_per_batch) appended a whole extra batch
    # when frames was already a multiple; that batch was run and then discarded.
    padding = -frames % frames_per_batch
    if frames == 0:
        # Keep one all-zero batch for empty input so the model still runs once
        # and the (empty) result is produced exactly as before.
        padding = frames_per_batch
    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
    feature, _ = reshape_inputs(feature=feature,
                                batch_size=batch_size,
                                timesteps=timesteps,
                                flatten=model_metadata.flattened,
                                add1ch=model_metadata.channel)
    sequences = feature.shape[0] // model_metadata.input_shape[0]
    feature = np.reshape(feature, [sequences, *model_metadata.input_shape])

    model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
    output_names = [n.name for n in model.get_outputs()]
    input_names = [n.name for n in model.get_inputs()]

    predict = []
    for sequence in range(sequences):
        if reset and sequence > 0:
            # Start each subsequent sequence from a freshly loaded session; the
            # first one already uses the session created above.
            model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
        predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))

    predict_arr = np.vstack(predict)
    # Combine [sequences, batch_size, ...] into [frames, ...]
    predict_shape = predict_arr.shape
    predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
    predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
    # Drop the rows that correspond to the zero padding added above.
    predict_arr = predict_arr[:frames, :]

    return predict_arr
|
233
|
+
|
234
|
+
|
235
|
+
if __name__ == '__main__':
|
236
|
+
try:
|
237
|
+
main()
|
238
|
+
except KeyboardInterrupt:
|
239
|
+
logger.info('Canceled due to keyboard interrupt')
|
240
|
+
raise SystemExit(0)
|