sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +36 -4
- sonusai/audiofe.py +111 -106
- sonusai/calc_metric_spenh.py +38 -22
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +15 -7
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +1 -0
- sonusai/mixture/config.py +1 -2
- sonusai/mkmanifest.py +43 -8
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict.py +16 -6
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/__init__.py +4 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -542
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/torchl_train.py
DELETED
@@ -1,223 +0,0 @@
|
|
1
|
-
"""sonusai torchl_train
|
2
|
-
|
3
|
-
usage: torchl_train [-hgv] (-m MODEL) (-l VLOC) [-w WEIGHTS] [-k CKPT]
|
4
|
-
[-e EPOCHS] [-b BATCH] [-t TSTEPS] [-p ESP] TLOC
|
5
|
-
|
6
|
-
options:
|
7
|
-
-h, --help
|
8
|
-
-v, --verbose Be verbose.
|
9
|
-
-m MODEL, --model MODEL Python .py file with MyHyperModel custom PL class definition.
|
10
|
-
-l VLOC, --vloc VLOC Location of SonusAI mixture database to use for validation.
|
11
|
-
-w WEIGHTS, --weights WEIGHTS Optional PL checkpoint file for initializing model weights.
|
12
|
-
-k CKPT, --ckpt CKPT Optional PL checkpoint file for full resume of training.
|
13
|
-
-e EPOCHS, --epochs EPOCHS Number of epochs to use in training. [default: 8].
|
14
|
-
-b BATCH, --batch BATCH Batch size.
|
15
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps.
|
16
|
-
-p ESP, --patience ESP Early stopping patience. [default: 12]
|
17
|
-
-g, --loss-batch-log Enable per-batch loss log. [default: False]
|
18
|
-
|
19
|
-
Train a PL (Pytorch Lightning) model defined in MODEL .py using Sonusai mixture data in TLOC.
|
20
|
-
|
21
|
-
Inputs:
|
22
|
-
TLOC A SonusAI mixture database directory to use for training data.
|
23
|
-
VLOC A SonusAI mixture database directory to use for validation data.
|
24
|
-
|
25
|
-
Results are written into subdirectory <MODEL>-<TIMESTAMP>.
|
26
|
-
Per-batch loss history, if enabled, is written to <basename>-history-lossb.npy
|
27
|
-
|
28
|
-
"""
|
29
|
-
from sonusai import logger
|
30
|
-
|
31
|
-
|
32
|
-
def main() -> None:
    """Run the ``torchl_train`` command.

    Parses the command line against the module docstring, creates a
    timestamped output directory, builds ``MyHyperModel`` from the MODEL file
    and fits it with a Lightning ``Trainer``, using the TLOC mixture database
    for training and the VLOC mixture database for validation.

    Raises:
        SystemExit: with status 1 when constructing ``MyHyperModel`` fails.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weights_name = args['--weights']
    ckpt_name = args['--ckpt']
    v_name = args['--vloc']
    epochs = int(args['--epochs'])
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    esp = int(args['--patience'])
    # NOTE(review): '--loss-batch-log' is accepted by the usage text but nothing
    # in this function acts on it yet.
    t_name = args['TLOC']

    import warnings
    from os import makedirs
    from os.path import basename
    from os.path import join
    from os.path import splitext

    # NOTE(review): no import is wrapped by this block any more, so it appears
    # to suppress nothing - confirm before removing.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import TorchFromMixtureDatabase
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import create_ts_name
    from sonusai.utils import import_keras_model
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.callbacks import ModelSummary
    # Take CSVLogger from the same package as Trainer; the original pulled it
    # from the standalone 'pytorch_lightning' distribution instead.
    from lightning.pytorch.loggers import CSVLogger
    from lightning.pytorch.loggers import TensorBoardLogger

    model_base = basename(model_name)
    model_root = splitext(model_base)[0]

    # Optional overrides arrive as strings (or None when not given).
    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name(model_root)
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'torchl_train.log'))
    update_console_handler(verbose)
    initial_log_messages('torchl_train')
    logger.info('')

    t_mixdb = MixtureDatabase(t_name)
    logger.info(f'Training: found {t_mixdb.num_mixtures} mixtures with {t_mixdb.num_classes} classes from {t_name}')

    v_mixdb = MixtureDatabase(v_name)
    logger.info(f'Validation: found {v_mixdb.num_mixtures} mixtures with {v_mixdb.num_classes} classes from {v_name}')

    # Import model definition file
    logger.info(f'Importing {model_base}')
    litemodel = import_keras_model(model_name)  # note works for PL as well as keras

    logger.info('Building and compiling model')
    try:
        model = litemodel.MyHyperModel(feature=t_mixdb.feature,
                                       timesteps=timesteps,
                                       batch_size=batch_size)
    except Exception as e:
        logger.exception(f'Error: building {model_base} failed: {e}')
        raise SystemExit(1) from e

    logger.info('')
    logger.info('')
    logger.info(f'feature {model.hparams.feature}')
    logger.info(f'batch_size {model.hparams.batch_size}')
    logger.info(f'timesteps {model.hparams.timesteps}')
    logger.info(f'num_classes {model.num_classes}')
    logger.info(f'flatten {model.flatten}')
    logger.info(f'add1ch {model.add1ch}')
    logger.info(f'input_shape {model.input_shape}')
    logger.info(f'truth_mutex {model.truth_mutex}')
    logger.info('')

    t_mixid = t_mixdb.mixids_to_list()
    v_mixid = v_mixdb.mixids_to_list()

    # Use SonusAI DataGenerator to create training/validation feature/truth on the fly
    sampler = None  # TBD how to stratify, also see stratified_shuffle_split_mixid(t_mixdb, vsplit=0)
    t_datagen = TorchFromMixtureDatabase(mixdb=t_mixdb,
                                         mixids=t_mixid,
                                         batch_size=model.hparams.batch_size,
                                         cut_len=model.hparams.timesteps,
                                         flatten=model.flatten,
                                         add1ch=model.add1ch,
                                         random_cut=True,
                                         sampler=sampler,
                                         drop_last=True,
                                         num_workers=4)

    # Validation runs one mixture per batch with no cutting (cut_len=0).
    v_datagen = TorchFromMixtureDatabase(mixdb=v_mixdb,
                                         mixids=v_mixid,
                                         batch_size=1,
                                         cut_len=0,
                                         flatten=model.flatten,
                                         add1ch=model.add1ch,
                                         random_cut=False,
                                         sampler=sampler,
                                         drop_last=True,
                                         num_workers=0)

    csvl = CSVLogger(output_dir, name="logs", version="")
    tbl = TensorBoardLogger(output_dir, "logs", "", log_graph=True, default_hp_metric=False)
    es_cb = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=esp, verbose=False, mode="min")
    # Keep the five best checkpoints by validation loss, plus the most recent one.
    ckpt_topv = ModelCheckpoint(dirpath=output_dir + '/ckpt/', save_top_k=5, monitor="val_loss",
                                mode="min", filename=model_root + "-{epoch:03d}-{val_loss:.3g}")
    ckpt_last = ModelCheckpoint(dirpath=output_dir + '/ckpt/', save_last=True)
    callbacks = [ModelSummary(max_depth=2), ckpt_topv, es_cb, ckpt_last]

    # Developer toggle: set to 'advanced' to write profiler output to output_dir.
    profiler = None  # 'advanced'
    if profiler == 'advanced':
        from lightning.pytorch.profilers import AdvancedProfiler
        profiler = AdvancedProfiler(dirpath=output_dir, filename="perf_logs")
    else:
        profiler = None

    # --weights only initializes parameters; it is ignored when --ckpt requests
    # a full resume.
    if weights_name is not None and ckpt_name is None:
        logger.info(f'Loading weights from {weights_name}')
        model = litemodel.MyHyperModel.load_from_checkpoint(weights_name,
                                                            feature=t_mixdb.feature,
                                                            timesteps=timesteps,
                                                            batch_size=batch_size)

    if ckpt_name is not None:
        logger.info(f'Loading full checkpoint and resuming training from {ckpt_name}')
        ckpt_path = ckpt_name
    else:
        ckpt_path = None

    logger.info(f' training mixtures {len(t_mixid)}')
    logger.info(f' validation mixtures {len(v_mixid)}')
    logger.info(f'Starting training with early stopping patience = {esp} ...')
    logger.info('')

    trainer = Trainer(max_epochs=epochs,
                      default_root_dir=output_dir,
                      logger=[tbl, csvl],
                      log_every_n_steps=10,
                      profiler=profiler,
                      callbacks=callbacks)
    trainer.fit(model, t_datagen, v_datagen, ckpt_path=ckpt_path)
|
216
|
-
|
217
|
-
|
218
|
-
if __name__ == '__main__':
    # Script entry point: run the trainer and treat Ctrl-C as a clean cancel.
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is a helper installed by the site module for interactive use;
        # raise SystemExit directly so this also works where site is not loaded.
        raise SystemExit(0)
|
@@ -1,59 +0,0 @@
|
|
1
|
-
from sonusai.utils import ASRResult
|
2
|
-
from sonusai.utils.asr_functions.data import Data
|
3
|
-
|
4
|
-
|
5
|
-
def aixplain_whisper(data: Data) -> ASRResult:
    """Transcribe ``data.audio`` with an aiXplain-hosted Whisper model.

    If ``data.whisper_model`` is None, the model is fetched with
    ``ModelFactory.get`` using the key stored in the environment variable
    ``AIXP_WHISPER_<WHISPER_MODEL_NAME>`` (name upper-cased). The audio is
    written to a temporary WAV file and submitted, retrying up to five times.

    Args:
        data: ASR request; reads ``audio``, ``whisper_model`` and
            ``whisper_model_name``.

    Returns:
        ASRResult built from the ``'data'`` and ``'confidence'`` entries of
        the service response.

    Raises:
        SonusAIError: if no model and no model name are given, if the
            environment variable is missing, or if every attempt fails.
    """
    import tempfile
    from os import getenv
    from os.path import join
    from random import random
    from time import sleep

    from aixplain.factories.model_factory import ModelFactory

    from sonusai import SonusAIError
    from sonusai.utils import ASRResult
    from sonusai.utils import float_to_int16
    from sonusai.utils import write_wav

    whisper_model = data.whisper_model
    if whisper_model is None:
        # Fail with a clear message instead of AttributeError on None.upper().
        if data.whisper_model_name is None:
            raise SonusAIError('aixplain_whisper requires whisper_model or whisper_model_name')

        envvar = 'AIXP_WHISPER_' + data.whisper_model_name.upper()
        modelkey = getenv(envvar)
        if modelkey is None:
            raise SonusAIError(f'{envvar} environment variable does not exist')

        whisper_model = ModelFactory.get(modelkey)

    with tempfile.TemporaryDirectory() as tmp:
        file = join(tmp, 'asr.wav')
        write_wav(name=file, audio=float_to_int16(data.audio))

        retry = 5
        count = 0
        while True:
            try:
                results = whisper_model.run(file)
                return ASRResult(text=results['data'], confidence=results['confidence'])
            except Exception as e:
                count += 1
                print(f'Warning: aiXplain exception: {e}')
                if count >= retry:
                    raise SonusAIError(f'Whisper exception: {e.args}') from e
                # Randomized, growing back-off before the next attempt, matching
                # the other remote ASR backends.
                sleep(count * (1 + random()))
|
41
|
-
|
42
|
-
|
43
|
-
"""
|
44
|
-
aiXplain Whisper results:
|
45
|
-
{
|
46
|
-
'completed': True,
|
47
|
-
'data': 'The birch canoe slid on the smooth planks.',
|
48
|
-
'usedCredits': 3.194770833333333e-05,
|
49
|
-
'runTime': 114.029,
|
50
|
-
'confidence': None,
|
51
|
-
'details': [],
|
52
|
-
'rawData': {
|
53
|
-
'predictions': [
|
54
|
-
' The birch canoe slid on the smooth planks.'
|
55
|
-
]
|
56
|
-
},
|
57
|
-
'status': 'SUCCESS'
|
58
|
-
}
|
59
|
-
"""
|
@@ -1,16 +0,0 @@
|
|
1
|
-
from dataclasses import dataclass
|
2
|
-
from typing import Any
|
3
|
-
from typing import Optional
|
4
|
-
|
5
|
-
from sonusai.mixture.datatypes import AudioT
|
6
|
-
|
7
|
-
|
8
|
-
@dataclass(frozen=True)
class Data:
    """Immutable bundle of inputs passed to an ASR backend function.

    Only ``audio`` is required; each backend reads the subset of the remaining
    fields it needs.
    """
    # Audio to transcribe; backends either write it to a WAV file or pass it
    # directly to the model's transcribe call.
    audio: AudioT
    # Backend-specific model handle. When None, backends fall back to
    # whisper_model_name to obtain a model.
    whisper_model: Optional[Any] = None
    # Model name used to load/look up a model when whisper_model is None.
    whisper_model_name: Optional[str] = None
    # Forwarded as device= when the whisper / faster_whisper model is created.
    device: Optional[str] = None
    # Forwarded as cpu_threads= to faster_whisper's WhisperModel.
    cpu_threads: Optional[int] = None
    # Forwarded as compute_type= to faster_whisper's WhisperModel.
    compute_type: Optional[str] = None
    # Converted with int() and forwarded as beam_size= to faster_whisper's transcribe.
    beam_size: Optional[int] = None
|
@@ -1,97 +0,0 @@
|
|
1
|
-
from sonusai.utils import ASRResult
|
2
|
-
from sonusai.utils.asr_functions.data import Data
|
3
|
-
|
4
|
-
|
5
|
-
def deepgram(data: Data) -> ASRResult:
    """Transcribe ``data.audio`` with the Deepgram prerecorded API.

    The audio is written to a temporary WAV file and uploaded with
    ``sync_prerecorded``. Failed requests are retried up to five times with a
    randomized, growing delay between attempts.

    Args:
        data: ASR request; only ``audio`` is read.

    Returns:
        ASRResult holding the transcript and confidence of the first
        alternative of the first channel.

    Raises:
        SonusAIError: if ``DEEPGRAM_API_KEY`` is not set or every attempt fails.
    """
    import tempfile
    from os import getenv
    from os.path import join
    from random import random
    from time import sleep

    import magic
    from deepgram import Deepgram
    from deepgram._types import BufferSource

    from sonusai import SonusAIError
    from sonusai.utils import float_to_int16
    from sonusai.utils import write_wav

    key = getenv('DEEPGRAM_API_KEY')
    if key is None:
        raise SonusAIError('DEEPGRAM_API_KEY environment variable does not exist')

    client = Deepgram(key)
    with tempfile.TemporaryDirectory() as tmp:
        file = join(tmp, 'asr.wav')
        write_wav(name=file, audio=float_to_int16(data.audio))

        mimetype = magic.from_file(file, mime=True)
        with open(file, 'rb') as audio:
            source = BufferSource(buffer=audio, mimetype=mimetype)
            retry = 5
            count = 0
            while True:
                try:
                    # Rewind before every attempt: a previous attempt may have
                    # consumed the stream, which would make a retry upload
                    # nothing. This is a no-op on the first pass.
                    audio.seek(0)
                    results = client.transcription.sync_prerecorded(source)['results']
                    return ASRResult(text=results['channels'][0]['alternatives'][0]['transcript'],
                                     confidence=results['channels'][0]['alternatives'][0]['confidence'])
                except Exception as e:
                    count += 1
                    print(f'Warning: Deepgram exception: {e}')
                    if count >= retry:
                        raise SonusAIError(f'Deepgram exception: {e.args}') from e
                    sleep(count * (1 + random()))
|
45
|
-
|
46
|
-
|
47
|
-
"""
|
48
|
-
Deepgram results:
|
49
|
-
{'metadata': {'channels': 1,
|
50
|
-
'created': '2023-01-30T21:49:44.048Z',
|
51
|
-
'duration': 2.3795626,
|
52
|
-
'model_info': {'c12089d0-0766-4ca0-9511-98fd2e443ebd': {'name': 'general',
|
53
|
-
'tier': 'base',
|
54
|
-
'version': '2022-01-18.1'}},
|
55
|
-
'models': ['c12089d0-0766-4ca0-9511-98fd2e443ebd'],
|
56
|
-
'request_id': 'e1154979-07f7-46a3-89e6-d5d796676d31',
|
57
|
-
'sha256': '3cad2f30a83e351eab3c4dcaa2ec47185e8f2979c90abec0c2332a7eef7c2d40',
|
58
|
-
'transaction_key': 'deprecated'},
|
59
|
-
'results': {'channels': [{'alternatives': [{'confidence': 0.9794922,
|
60
|
-
'transcript': 'the birch can canoe slid on the smooth planks',
|
61
|
-
'words': [{'confidence': 0.9794922,
|
62
|
-
'end': 0.29625,
|
63
|
-
'start': 0.13825,
|
64
|
-
'word': 'the'},
|
65
|
-
{'confidence': 0.9902344,
|
66
|
-
'end': 0.57275,
|
67
|
-
'start': 0.29625,
|
68
|
-
'word': 'birch'},
|
69
|
-
{'confidence': 0.73535156,
|
70
|
-
'end': 0.73074996,
|
71
|
-
'start': 0.57275,
|
72
|
-
'word': 'can'},
|
73
|
-
{'confidence': 0.9550781,
|
74
|
-
'end': 1.08625,
|
75
|
-
'start': 0.73074996,
|
76
|
-
'word': 'canoe'},
|
77
|
-
{'confidence': 0.98876953,
|
78
|
-
'end': 1.2442499,
|
79
|
-
'start': 1.08625,
|
80
|
-
'word': 'slid'},
|
81
|
-
{'confidence': 0.9921875,
|
82
|
-
'end': 1.3627499,
|
83
|
-
'start': 1.2442499,
|
84
|
-
'word': 'on'},
|
85
|
-
{'confidence': 0.9584961,
|
86
|
-
'end': 1.5997499,
|
87
|
-
'start': 1.3627499,
|
88
|
-
'word': 'the'},
|
89
|
-
{'confidence': 0.9970703,
|
90
|
-
'end': 1.9947499,
|
91
|
-
'start': 1.5997499,
|
92
|
-
'word': 'smooth'},
|
93
|
-
{'confidence': 0.98828125,
|
94
|
-
'end': 2.23175,
|
95
|
-
'start': 1.9947499,
|
96
|
-
'word': 'planks'}]}]}]}}
|
97
|
-
"""
|
@@ -1,90 +0,0 @@
|
|
1
|
-
from sonusai.utils import ASRResult
|
2
|
-
from sonusai.utils.asr_functions.data import Data
|
3
|
-
|
4
|
-
|
5
|
-
def fastwhisper(data: Data) -> ASRResult:
    """Transcribe ``data.audio`` with a faster-whisper ``WhisperModel``.

    The model is created from ``data.whisper_model`` (with
    ``local_files_only=True``) when that is set, otherwise from
    ``data.whisper_model_name``. Model creation and transcription are
    attempted at most twice.

    Args:
        data: ASR request; reads ``audio``, ``whisper_model``,
            ``whisper_model_name``, ``device``, ``cpu_threads``,
            ``compute_type`` and ``beam_size``.

    Returns:
        ASRResult with the joined segment text, detected language and its
        probability, duration, segment count and elapsed transcription time.

    Raises:
        SonusAIError: if both attempts raise.
    """
    from os import getpid
    from timeit import default_timer as timer

    from faster_whisper import WhisperModel

    from sonusai import SonusAIError

    preloaded = data.whisper_model
    pid = getpid()
    max_attempts = 2
    for attempt in range(1, max_attempts + 1):
        try:
            shared_kwargs = dict(device=data.device,
                                 cpu_threads=data.cpu_threads,
                                 compute_type=data.compute_type)
            # To pre-download model, first provide whisper_model_name without whisper_model (or =None)
            if preloaded is not None:
                model = WhisperModel(preloaded, local_files_only=True, **shared_kwargs)
            else:
                model = WhisperModel(data.whisper_model_name, **shared_kwargs)

            started = timer()
            segments, info = model.transcribe(data.audio, beam_size=int(data.beam_size))
            # transcribe() is lazy; materializing the segments runs the transcription.
            segments = list(segments)
            elapsed = timer() - started

            text = "".join(segment.text for segment in segments)
            return ASRResult(text=text,
                             lang=info.language,
                             lang_prob=info.language_probability,
                             duration=info.duration,
                             num_segments=len(segments),
                             asr_cpu_time=elapsed)
        except Exception as e:
            print(f'{pid}: Warning: fastwhisper exception: {e}')
            if attempt >= max_attempts:
                raise SonusAIError(f'{pid}: Fastwhisper exception: {e.args}')
|
54
|
-
|
55
|
-
|
56
|
-
"""
|
57
|
-
Whisper results:
|
58
|
-
{
|
59
|
-
'text': ' The birch canoe slid on the smooth planks.',
|
60
|
-
'segments': [
|
61
|
-
{
|
62
|
-
'id': 0,
|
63
|
-
'seek': 0,
|
64
|
-
'start': 0.0,
|
65
|
-
'end': 2.4,
|
66
|
-
'text': ' The birch canoe slid on the smooth planks.',
|
67
|
-
'tokens': [
|
68
|
-
50363,
|
69
|
-
383,
|
70
|
-
35122,
|
71
|
-
354,
|
72
|
-
47434,
|
73
|
-
27803,
|
74
|
-
319,
|
75
|
-
262,
|
76
|
-
7209,
|
77
|
-
1410,
|
78
|
-
591,
|
79
|
-
13,
|
80
|
-
50483
|
81
|
-
],
|
82
|
-
'temperature': 0.0,
|
83
|
-
'avg_logprob': -0.4188103675842285,
|
84
|
-
'compression_ratio': 0.8571428571428571,
|
85
|
-
'no_speech_prob': 0.003438911633566022
|
86
|
-
}
|
87
|
-
],
|
88
|
-
'language': 'en'
|
89
|
-
}
|
90
|
-
"""
|
@@ -1,95 +0,0 @@
|
|
1
|
-
from sonusai.utils import ASRResult
|
2
|
-
from sonusai.utils.asr_functions.data import Data
|
3
|
-
|
4
|
-
|
5
|
-
def google(data: Data) -> ASRResult:
    """Transcribe ``data.audio`` with Google Speech Recognition.

    The audio is written to a temporary WAV file, loaded through
    ``speech_recognition`` and sent with ``recognize_google(show_all=True)``.
    Responses without any alternative are retried up to five times with a
    randomized, growing delay.

    Args:
        data: ASR request; only ``audio`` is read.

    Returns:
        ASRResult for the alternative with the highest reported confidence,
        or for the first alternative when none reports a confidence (in which
        case the confidence defaults to 0.5). An empty result with confidence
        0 is returned on ``sr.UnknownValueError``.

    Raises:
        SonusAIError: if ``GOOGLE_SPEECH_API_KEY`` is not set, the retry count
            is exceeded, the chosen alternative has no transcript, or the
            service request fails.
    """
    import tempfile
    from os import getenv
    from os.path import getsize
    from os.path import join
    from random import random
    from time import sleep

    import speech_recognition as sr

    from sonusai import SonusAIError
    from sonusai.utils import float_to_int16
    from sonusai.utils import human_readable_size
    from sonusai.utils import write_wav

    key = getenv('GOOGLE_SPEECH_API_KEY')
    if key is None:
        raise SonusAIError('GOOGLE_SPEECH_API_KEY environment variable does not exist')

    r = sr.Recognizer()
    with tempfile.TemporaryDirectory() as tmp:
        file = join(tmp, 'asr.wav')
        write_wav(name=file, audio=float_to_int16(data.audio))
        size = getsize(file)
        if size > 10 * 1024 * 1024:
            print(f'Warning: file size exceeds Google single request limit: {human_readable_size(size)} > 10 MB')

        with sr.AudioFile(file) as source:
            audio = r.record(source)

        try:
            retry = 5
            count = 0
            # Pre-bind so the warning below cannot hit an unbound name when
            # recognize_google itself raises ValueError on the first attempt.
            results = None
            while True:
                try:
                    sleep(count * (1 + random()))
                    results = r.recognize_google(audio, key=key, show_all=True)
                    if not isinstance(results, dict) or len(results.get('alternative', [])) == 0:
                        raise ValueError
                    break
                except ValueError:
                    print(f'Warning: speech_recognition ValueError {count}\n{results}')
                    count += 1
                    if count >= retry:
                        raise SonusAIError('speech_recognition exception: ValueError retry count exceeded.')

            alternatives = results['alternative']
            # 'alternative' is a list of dicts, so the confidence key has to be
            # looked up per entry rather than in the list itself.
            scored = [alt for alt in alternatives if alt.get('confidence') is not None]
            if scored:
                # return alternative with highest confidence score
                best_hypothesis = max(scored, key=lambda alt: alt['confidence'])
            else:
                # when there is no confidence available, we arbitrarily choose the first hypothesis.
                best_hypothesis = alternatives[0]
            if "transcript" not in best_hypothesis:
                raise SonusAIError('speech_recognition: UnknownValueError')
            confidence = best_hypothesis.get('confidence', 0.5)
            return ASRResult(text=best_hypothesis['transcript'], confidence=confidence)
        except sr.UnknownValueError:
            return ASRResult(text='', confidence=0)
        except sr.RequestError as e:
            raise SonusAIError(f'Could not request results from Google Speech Recognition service: {e}')
|
65
|
-
|
66
|
-
|
67
|
-
"""
|
68
|
-
Google results:
|
69
|
-
{
|
70
|
-
"result": [
|
71
|
-
{
|
72
|
-
"alternative": [
|
73
|
-
{
|
74
|
-
"transcript": "the Birch canoe slid on the smooth planks",
|
75
|
-
"confidence": 0.94228178
|
76
|
-
},
|
77
|
-
{
|
78
|
-
"transcript": "the Burj canoe slid on the smooth planks"
|
79
|
-
},
|
80
|
-
{
|
81
|
-
"transcript": "the Birch canoe slid on the smooth plank"
|
82
|
-
},
|
83
|
-
{
|
84
|
-
"transcript": "the Birch canoe slit on the smooth planks"
|
85
|
-
},
|
86
|
-
{
|
87
|
-
"transcript": "the Birch canoes slid on the smooth planks"
|
88
|
-
}
|
89
|
-
],
|
90
|
-
"final": true
|
91
|
-
}
|
92
|
-
],
|
93
|
-
"result_index": 0
|
94
|
-
}
|
95
|
-
"""
|
@@ -1,49 +0,0 @@
|
|
1
|
-
from sonusai.utils import ASRResult
|
2
|
-
from sonusai.utils.asr_functions.data import Data
|
3
|
-
|
4
|
-
|
5
|
-
def whisper(data: Data) -> ASRResult:
    """Transcribe ``data.audio`` with an OpenAI Whisper model.

    Uses ``data.whisper_model`` when it is set; otherwise loads the model
    named ``data.whisper_model_name`` on ``data.device``.

    Args:
        data: ASR request; reads ``audio``, ``whisper_model``,
            ``whisper_model_name`` and ``device``.

    Returns:
        ASRResult containing the ``'text'`` entry of the transcription.
    """
    from whisper import load_model

    model = (
        data.whisper_model
        if data.whisper_model is not None
        else load_model(data.whisper_model_name, device=data.device)
    )
    transcription = model.transcribe(data.audio, fp16=False)
    return ASRResult(text=transcription['text'])
|
13
|
-
|
14
|
-
|
15
|
-
"""
|
16
|
-
Whisper results:
|
17
|
-
{
|
18
|
-
'text': ' The birch canoe slid on the smooth planks.',
|
19
|
-
'segments': [
|
20
|
-
{
|
21
|
-
'id': 0,
|
22
|
-
'seek': 0,
|
23
|
-
'start': 0.0,
|
24
|
-
'end': 2.4,
|
25
|
-
'text': ' The birch canoe slid on the smooth planks.',
|
26
|
-
'tokens': [
|
27
|
-
50363,
|
28
|
-
383,
|
29
|
-
35122,
|
30
|
-
354,
|
31
|
-
47434,
|
32
|
-
27803,
|
33
|
-
319,
|
34
|
-
262,
|
35
|
-
7209,
|
36
|
-
1410,
|
37
|
-
591,
|
38
|
-
13,
|
39
|
-
50483
|
40
|
-
],
|
41
|
-
'temperature': 0.0,
|
42
|
-
'avg_logprob': -0.4188103675842285,
|
43
|
-
'compression_ratio': 0.8571428571428571,
|
44
|
-
'no_speech_prob': 0.003438911633566022
|
45
|
-
}
|
46
|
-
],
|
47
|
-
'language': 'en'
|
48
|
-
}
|
49
|
-
"""
|