sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -60,6 +60,7 @@ Metric and extraction data are written into prediction location PLOC as separate
  Inputs:

  """
+ import signal
  from dataclasses import dataclass
  from typing import Optional

@@ -67,14 +68,24 @@ import matplotlib
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
-
- from sonusai import logger
  from sonusai.mixture import AudioF
  from sonusai.mixture import AudioT
  from sonusai.mixture import Feature
  from sonusai.mixture import MixtureDatabase
  from sonusai.mixture import Predict

+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+

  matplotlib.use('SVG')


@@ -1145,7 +1156,7 @@ def main():
          fnb = 'metric_spenh_whspaaw_' + whisper_model + '_'
          logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
          enable_asr_warmup = True
-     elif wer_method == 'fastwhisper':
+     elif wer_method == 'faster_whisper':
          fnb = 'metric_spenh_fwhsp_' + whisper_model + '_'
          logger.info(f'WER enabled with method {wer_method} and whisper model {whisper_model}')
          enable_asr_warmup = True
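The rename above tracks the upstream package name: the ASR backend is the faster-whisper library, so the wer_method string is now 'faster_whisper' rather than 'fastwhisper'. For orientation only, a minimal sketch of the upstream faster-whisper API follows; it is not the sonusai wrapper, and the model size, device, file name, and decoding options are placeholders.

    # Upstream faster-whisper usage sketch (illustrative; not sonusai code).
    from faster_whisper import WhisperModel

    model = WhisperModel('tiny', device='cpu', compute_type='int8')   # model size and device are placeholders
    segments, info = model.transcribe('speech.wav', beam_size=5)      # 'speech.wav' is a hypothetical file
    hypothesis = ' '.join(segment.text.strip() for segment in segments)
    print(hypothesis)  # hypothesis text that a WER metric would compare against the reference transcript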
@@ -1326,8 +1337,4 @@ def main():


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         exit()
+     main()
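Every command-line entry point changed in 0.17.0 follows the same refactor shown above: the try/except KeyboardInterrupt wrapper around main() is replaced by a module-level SIGINT handler registered at import time, and the exit status on interrupt changes from 0 to 1. A minimal standalone sketch of the pattern (the long-running loop is a stand-in, not sonusai code):

    # Sketch of the 0.17.0 interrupt-handling pattern.
    import signal
    import sys
    import time


    def signal_handler(_sig, _frame):
        # Runs whenever the process receives SIGINT (Ctrl+C), including inside worker loops.
        print('Canceled due to keyboard interrupt')
        sys.exit(1)


    signal.signal(signal.SIGINT, signal_handler)


    def main() -> None:
        while True:  # stand-in for a long-running job
            time.sleep(1)


    if __name__ == '__main__':
        main()  # no try/except needed; the handler exits with status 1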
sonusai/genft.py CHANGED
@@ -23,14 +23,26 @@ Outputs the following to the mixture database directory:
  genft.log

  """
+ import signal
  from dataclasses import dataclass

- from sonusai import logger
  from sonusai.mixture import GenFTData
  from sonusai.mixture import GeneralizedIDs
  from sonusai.mixture import MixtureDatabase


+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
  @dataclass
  class MPGlobal:
      mixdb: MixtureDatabase = None
@@ -123,6 +135,7 @@ def main() -> None:

      from sonusai import create_file_handler
      from sonusai import initial_log_messages
+     from sonusai import logger
      from sonusai import update_console_handler
      from sonusai.mixture import check_audio_files_exist
      from sonusai.utils import human_readable_size
@@ -177,8 +190,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/genmix.py CHANGED
@@ -27,14 +27,26 @@ Outputs the following to the mixture database directory:
  <id>.txt
  genmix.log
  """
+ import signal
  from dataclasses import dataclass

- from sonusai import logger
  from sonusai.mixture import GenMixData
  from sonusai.mixture import GeneralizedIDs
  from sonusai.mixture import MixtureDatabase


+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
  @dataclass
  class MPGlobal:
      mixdb: MixtureDatabase = None
@@ -210,8 +222,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/genmixdb.py CHANGED
@@ -112,13 +112,25 @@ targets:
  will find all .wav files in the specified directories and process them as targets.

  """
+ import signal
  from dataclasses import dataclass

- from sonusai import logger
  from sonusai.mixture import Mixture
  from sonusai.mixture import MixtureDatabase


+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
  @dataclass
  class MPGlobal:
      mixdb: MixtureDatabase = None
@@ -509,8 +521,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/gentcst.py CHANGED
@@ -44,10 +44,21 @@ Outputs:
  gentcst.log

  """
+ import signal
  from dataclasses import dataclass
  from typing import Optional

- from sonusai import logger
+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)

  CONFIG_FILE = 'config.yml'

@@ -621,8 +632,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/lsdb.py CHANGED
@@ -15,11 +15,25 @@ Inputs:
  LOC A SonusAI mixture database directory.

  """
+ import signal
+
  from sonusai import logger
  from sonusai.mixture import GeneralizedIDs
  from sonusai.mixture import MixtureDatabase


+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
  def lsdb(mixdb: MixtureDatabase,
           mixids: GeneralizedIDs = None,
           truth_index: int = None,
@@ -142,8 +156,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/mkmanifest.py CHANGED
@@ -46,7 +46,19 @@ Example usage for LibriSpeech:
  sonusai mkmanifest -mlibrispeech -eADAT -oasr_manifest.json --include='*.flac' train-clean-100
  sonusai mkmanifest -m mcgill-speech -e ADAT -o asr_manifest_16k.json 16k-LP7/
  """
- from sonusai import logger
+ import signal
+
+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)

  VALID_METHOD = ['librispeech', 'vctk_noisy_speech', 'mcgill-speech']

@@ -194,8 +206,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
sonusai/mkwav.py CHANGED
@@ -23,13 +23,25 @@ Outputs the following to the mixture database directory:
  mkwav.log

  """
+ import signal
  from dataclasses import dataclass

- from sonusai import logger
  from sonusai.mixture import AudioT
  from sonusai.mixture import MixtureDatabase


+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
  @dataclass
  class MPGlobal:
      mixdb: MixtureDatabase = None
@@ -120,6 +132,7 @@ def main() -> None:
      import sonusai
      from sonusai import create_file_handler
      from sonusai import initial_log_messages
+     from sonusai import logger
      from sonusai import update_console_handler
      from sonusai.mixture import check_audio_files_exist
      from sonusai.utils import pp_tqdm_imap
@@ -164,8 +177,4 @@ def main() -> None:


  if __name__ == '__main__':
-     try:
-         main()
-     except KeyboardInterrupt:
-         logger.info('Canceled due to keyboard interrupt')
-         raise SystemExit(0)
+     main()
@@ -0,0 +1,240 @@
+ """sonusai predict
+
+ usage: predict [-hvr] [-i MIXID] (-m MODEL) INPUT
+
+ options:
+     -h, --help
+     -v, --verbose               Be verbose.
+     -i MIXID, --mixid MIXID     Mixture ID(s) to generate if input is a mixture database. [default: *].
+     -m MODEL, --model MODEL     Trained ONNX model file.
+     -r, --reset                 Reset model between each file.
+
+ Run prediction on a trained ONNX model using SonusAI genft or WAV data.
+
+ Inputs:
+     MODEL   A SonusAI trained ONNX model file.
+     INPUT   The input data must be one of the following:
+             * WAV
+               Using the given model, generate feature data and run prediction. A model file must be
+               provided. The MIXID is ignored.
+
+             * directory
+               Using the given SonusAI mixture database directory, generate feature and truth data if not found.
+               Run prediction. The MIXID is required.
+
+ Outputs the following to opredict-<TIMESTAMP> directory:
+     <id>.h5
+         dataset: predict
+     onnx_predict.log
+
+ """
+
+ from sonusai import logger
+ from sonusai.mixture import Feature
+ from sonusai.mixture import Predict
+ from sonusai.utils import SonusAIMetaData
+
+
+ def main() -> None:
+     from docopt import docopt
+
+     import sonusai
+     from sonusai.utils import trim_docstring
+
+     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
+
+     verbose = args['--verbose']
+     mixids = args['--mixid']
+     model_name = args['--model']
+     reset = args['--reset']
+     input_name = args['INPUT']
+
+     from os import makedirs
+     from os.path import isdir
+     from os.path import join
+     from os.path import splitext
+
+     import h5py
+     import onnxruntime as rt
+     import numpy as np
+
+     from sonusai import create_file_handler
+     from sonusai import initial_log_messages
+     from sonusai import update_console_handler
+     from sonusai.mixture import MixtureDatabase
+     from sonusai.mixture import get_feature_from_audio
+     from sonusai.mixture import read_audio
+     from sonusai.utils import create_ts_name
+     from sonusai.utils import get_frames_per_batch
+     from sonusai.utils import get_sonusai_metadata
+
+     output_dir = create_ts_name('opredict')
+     makedirs(output_dir, exist_ok=True)
+
+     # Setup logging file
+     create_file_handler(join(output_dir, 'onnx_predict.log'))
+     update_console_handler(verbose)
+     initial_log_messages('onnx_predict')
+
+     model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+     model_metadata = get_sonusai_metadata(model)
+
+     batch_size = model_metadata.input_shape[0]
+     if model_metadata.timestep:
+         timesteps = model_metadata.input_shape[1]
+     else:
+         timesteps = 0
+     num_classes = model_metadata.output_shape[-1]
+
+     frames_per_batch = get_frames_per_batch(batch_size, timesteps)
+
+     logger.info('')
+     logger.info(f'feature {model_metadata.feature}')
+     logger.info(f'num_classes {num_classes}')
+     logger.info(f'batch_size {batch_size}')
+     logger.info(f'timesteps {timesteps}')
+     logger.info(f'flatten {model_metadata.flattened}')
+     logger.info(f'add1ch {model_metadata.channel}')
+     logger.info(f'truth_mutex {model_metadata.mutex}')
+     logger.info(f'input_shape {model_metadata.input_shape}')
+     logger.info(f'output_shape {model_metadata.output_shape}')
+     logger.info('')
+
+     if splitext(input_name)[1] == '.wav':
+         # Convert WAV to feature data
+         logger.info('')
+         logger.info(f'Run prediction on {input_name}')
+         audio = read_audio(input_name)
+         feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
+
+         predict = pad_and_predict(feature=feature,
+                                   model_name=model_name,
+                                   model_metadata=model_metadata,
+                                   frames_per_batch=frames_per_batch,
+                                   batch_size=batch_size,
+                                   timesteps=timesteps,
+                                   reset=reset)
+
+         output_name = splitext(input_name)[0] + '.h5'
+         with h5py.File(output_name, 'a') as f:
+             if 'feature' in f:
+                 del f['feature']
+             f.create_dataset(name='feature', data=feature)
+
+             if 'predict' in f:
+                 del f['predict']
+             f.create_dataset(name='predict', data=predict)
+
+         logger.info(f'Saved results to {output_name}')
+         return
+
+     if not isdir(input_name):
+         logger.exception(f'Do not know how to process input from {input_name}')
+         raise SystemExit(1)
+
+     mixdb = MixtureDatabase(input_name)
+
+     if mixdb.feature != model_metadata.feature:
+         logger.exception(f'Feature in mixture database does not match feature in model')
+         raise SystemExit(1)
+
+     mixids = mixdb.mixids_to_list(mixids)
+     if reset:
+         # reset mode cycles through each file one at a time
+         for mixid in mixids:
+             feature, _ = mixdb.mixture_ft(mixid)
+
+             predict = pad_and_predict(feature=feature,
+                                       model_name=model_name,
+                                       model_metadata=model_metadata,
+                                       frames_per_batch=frames_per_batch,
+                                       batch_size=batch_size,
+                                       timesteps=timesteps,
+                                       reset=reset)
+
+             output_name = join(output_dir, mixdb.mixtures[mixid].name)
+             with h5py.File(output_name, 'a') as f:
+                 if 'predict' in f:
+                     del f['predict']
+                 f.create_dataset(name='predict', data=predict)
+     else:
+         features: list[Feature] = []
+         file_indices: list[slice] = []
+         total_frames = 0
+         for mixid in mixids:
+             current_feature, _ = mixdb.mixture_ft(mixid)
+             current_frames = current_feature.shape[0]
+             features.append(current_feature)
+             file_indices.append(slice(total_frames, total_frames + current_frames))
+             total_frames += current_frames
+         feature = np.vstack([features[i] for i in range(len(features))])
+
+         predict = pad_and_predict(feature=feature,
+                                   model_name=model_name,
+                                   model_metadata=model_metadata,
+                                   frames_per_batch=frames_per_batch,
+                                   batch_size=batch_size,
+                                   timesteps=timesteps,
+                                   reset=reset)
+
+         # Write data to separate files
+         for idx, mixid in enumerate(mixids):
+             output_name = join(output_dir, mixdb.mixtures[mixid].name)
+             with h5py.File(output_name, 'a') as f:
+                 if 'predict' in f:
+                     del f['predict']
+                 f.create_dataset('predict', data=predict[file_indices[idx]])
+
+     logger.info(f'Saved results to {output_dir}')
+
+
+ def pad_and_predict(feature: Feature,
+                     model_name: str,
+                     model_metadata: SonusAIMetaData,
+                     frames_per_batch: int,
+                     batch_size: int,
+                     timesteps: int,
+                     reset: bool) -> Predict:
+     import onnxruntime as rt
+     import numpy as np
+
+     from sonusai.utils import reshape_inputs
+     from sonusai.utils import reshape_outputs
+
+     frames = feature.shape[0]
+     padding = frames_per_batch - frames % frames_per_batch
+     feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
+     feature, _ = reshape_inputs(feature=feature,
+                                 batch_size=batch_size,
+                                 timesteps=timesteps,
+                                 flatten=model_metadata.flattened,
+                                 add1ch=model_metadata.channel)
+     sequences = feature.shape[0] // model_metadata.input_shape[0]
+     feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
+
+     model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+     output_names = [n.name for n in model.get_outputs()]
+     input_names = [n.name for n in model.get_inputs()]
+
+     predict = []
+     for sequence in range(sequences):
+         predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
+         if reset:
+             model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+
+     predict_arr = np.vstack(predict)
+     # Combine [sequences, batch_size, ...] into [frames, ...]
+     predict_shape = predict_arr.shape
+     predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
+     predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
+     predict_arr = predict_arr[:frames, :]
+
+     return predict_arr
+
+
+ if __name__ == '__main__':
+     try:
+         main()
+     except KeyboardInterrupt:
+         logger.info('Canceled due to keyboard interrupt')
+         raise SystemExit(0)
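The padding logic in pad_and_predict rounds the feature up to a whole number of batches, runs the model one batch at a time, and then trims the output back to the original frame count with predict_arr[:frames, :]. Note that when frames is already a multiple of frames_per_batch, padding equals frames_per_batch, so one entirely zero batch is appended and later discarded. The sketch below illustrates the shape arithmetic with made-up numbers; it assumes frames_per_batch = batch_size * timesteps for timestep models and flattens the feature directly instead of calling reshape_inputs, so it is not the sonusai implementation.

    # Shape arithmetic behind pad_and_predict, with hypothetical numbers.
    import numpy as np

    frames, stride, bins = 100, 1, 64            # hypothetical feature shape [frames, stride, bins]
    batch_size, timesteps = 8, 4
    frames_per_batch = batch_size * timesteps    # 32 (assumed formula for timestep models)

    feature = np.zeros((frames, stride, bins), dtype=np.float32)
    padding = frames_per_batch - frames % frames_per_batch              # 32 - (100 % 32) = 28
    feature = np.pad(feature, ((0, padding), (0, 0), (0, 0)))           # -> (128, 1, 64)

    sequences = feature.shape[0] // frames_per_batch                    # 4 model calls
    batched = feature.reshape(sequences, batch_size, timesteps, stride * bins)

    print(padding, feature.shape, batched.shape)                        # 28 (128, 1, 64) (4, 8, 4, 64)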