sonusai 0.15.9__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +36 -4
- sonusai/audiofe.py +111 -106
- sonusai/calc_metric_spenh.py +38 -22
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +15 -7
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +1 -0
- sonusai/mixture/config.py +1 -2
- sonusai/mkmanifest.py +43 -8
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict.py +16 -6
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/__init__.py +4 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/METADATA +7 -25
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/RECORD +32 -46
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -542
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/WHEEL +0 -0
- {sonusai-0.15.9.dist-info → sonusai-0.16.1.dist-info}/entry_points.txt +0 -0
sonusai/data_generator/keras_from_mixdb.py
DELETED
@@ -1,169 +0,0 @@
|
|
1
|
-
import warnings
|
2
|
-
from dataclasses import dataclass
|
3
|
-
from multiprocessing import get_context
|
4
|
-
from os import cpu_count
|
5
|
-
from typing import Optional
|
6
|
-
|
7
|
-
import numpy as np
|
8
|
-
|
9
|
-
from sonusai.mixture import Feature
|
10
|
-
from sonusai.mixture import GeneralizedIDs
|
11
|
-
from sonusai.mixture import MixtureDatabase
|
12
|
-
from sonusai.mixture import Truth
|
13
|
-
|
14
|
-
with warnings.catch_warnings():
|
15
|
-
warnings.simplefilter('ignore')
|
16
|
-
from keras.utils import Sequence
|
17
|
-
|
18
|
-
|
19
|
-
@dataclass
class MPGlobal:
    """Per-process state used by the multiprocessing pool workers.

    ``_pool_initializer`` opens the mixture database in each worker and stores
    it here, so ``_pool_kernel`` only needs to receive a mixture ID per task.
    """

    # Mixture database opened by the worker; None until the pool initializer runs.
    mixdb: Optional[MixtureDatabase] = None


# Module-level instance; each forked pool worker populates its own copy.
MP_GLOBAL = MPGlobal()
|
25
|
-
|
26
|
-
|
27
|
-
def _pool_initializer(location: str) -> None:
    """Open the mixture database inside a pool worker process.

    Used as the ``initializer`` of the process pool created by
    ``KerasFromMixtureDatabase``; the opened database is stored on ``MP_GLOBAL``.

    :param location: Location of the mixture database to open.
    """
    database = MixtureDatabase(location)
    MP_GLOBAL.mixdb = database
|
29
|
-
|
30
|
-
|
31
|
-
def _pool_kernel(mixid: int) -> tuple[Feature, Truth]:
    """Load feature and truth data for one mixture inside a pool worker.

    :param mixid: Mixture ID to load.
    :return: The ``(feature, truth)`` pair returned by ``mixture_ft``.
    """
    database = MP_GLOBAL.mixdb
    return database.mixture_ft(mixid)
|
33
|
-
|
34
|
-
|
35
|
-
class KerasFromMixtureDatabase(Sequence):
    """Generates data for Keras from a SonusAI mixture database

    The selected mixtures are treated as one continuous run of feature frames
    that is cut into batches of ``get_frames_per_batch(batch_size, timesteps)``
    frames. A batch may start and/or end in the middle of a mixture, and the
    final batch is zero-padded when the frame count is not an exact multiple of
    the batch length. Mixture data for a batch is loaded in parallel with a
    forked process pool.
    """

    @dataclass(frozen=True)
    class BatchParams:
        # IDs of the mixtures that contribute frames to this batch, in order
        mixids: list[int]
        # Leading frames of the first mixture that belong to the previous batch
        offset: int
        # Trailing frames of the last mixture that belong to the next batch
        extra: int
        # Number of all-zero frames appended to complete the batch
        padding: int

    def __init__(self,
                 mixdb: MixtureDatabase,
                 mixids: GeneralizedIDs,
                 batch_size: int,
                 timesteps: int,
                 flatten: bool,
                 add1ch: bool,
                 shuffle: bool = False):
        """Initialization

        :param mixdb: Mixture database to read from.
        :param mixids: Mixture IDs to use (converted with ``mixdb.mixids_to_list``).
        :param batch_size: Batch size; also passed to ``reshape_inputs``.
        :param timesteps: Timesteps; also passed to ``reshape_inputs``.
        :param flatten: Passed to ``reshape_inputs``.
        :param add1ch: Passed to ``reshape_inputs``.
        :param shuffle: If True, shuffle the mixture order at the end of every epoch.
        """
        self.mixdb = mixdb
        self.mixids = self.mixdb.mixids_to_list(mixids)
        self.batch_size = batch_size
        self.timesteps = timesteps
        self.flatten = flatten
        self.add1ch = add1ch
        self.shuffle = shuffle
        self.stride = self.mixdb.fg_stride
        self.feature_parameters = self.mixdb.feature_parameters
        self.num_classes = self.mixdb.num_classes
        self.mixture_frame_segments: Optional[int] = None
        self.batch_frame_segments: Optional[int] = None
        self.total_batches: Optional[int] = None

        self._initialize_mixtures()

        # Each worker opens its own copy of the database via _pool_initializer
        self.pool = get_context('fork').Pool(processes=cpu_count(),
                                             initializer=_pool_initializer,
                                             initargs=(self.mixdb.location,))

    def __len__(self) -> int:
        """Denotes the number of batches per epoch
        """
        return self.total_batches

    @staticmethod
    def _pad_frames(data: np.ndarray, padding: int) -> np.ndarray:
        """Append ``padding`` all-zero frames along axis 0, keeping ``data``'s dtype.

        Padding with ``np.zeros`` of the default dtype would make ``np.vstack``
        promote the whole batch to float64, so the zeros use ``data.dtype``.
        A negative ``padding`` still raises from ``np.zeros``.
        """
        if padding == 0:
            return data
        pad_shape = list(data.shape)
        pad_shape[0] = padding
        return np.vstack([data, np.zeros(pad_shape, dtype=data.dtype)])

    def __getitem__(self, batch_index: int) -> tuple[np.ndarray, np.ndarray]:
        """Get one batch of data

        :param batch_index: Index of the batch to build.
        :return: ``(feature, truth)`` for the batch, after ``reshape_inputs``.
        """
        from sonusai.utils import reshape_inputs

        batch_params = self.batch_params[batch_index]

        result = self.pool.map(_pool_kernel, batch_params.mixids)
        feature = np.vstack([item[0] for item in result])
        truth = np.vstack([item[1] for item in result])

        feature = self._pad_frames(feature, batch_params.padding)
        truth = self._pad_frames(truth, batch_params.padding)

        # Drop frames that belong to the previous batch (offset) and the next batch (extra)
        stop = -batch_params.extra if batch_params.extra > 0 else None
        feature = feature[batch_params.offset:stop]
        truth = truth[batch_params.offset:stop]

        feature, truth = reshape_inputs(feature=feature,
                                        truth=truth,
                                        batch_size=self.batch_size,
                                        timesteps=self.timesteps,
                                        flatten=self.flatten,
                                        add1ch=self.add1ch)

        return feature, truth

    def on_epoch_end(self) -> None:
        """Modification of dataset between epochs

        When ``shuffle`` is set, reorder the mixtures and recompute the batch layout.
        """
        import random

        if self.shuffle:
            random.shuffle(self.mixids)
            self._initialize_mixtures()

    def _initialize_mixtures(self) -> None:
        """Compute ``total_batches``, ``batch_params`` and ``file_indices`` for the current mixture order."""
        from sonusai.utils import get_frames_per_batch

        frames_per_batch = get_frames_per_batch(self.batch_size, self.timesteps)
        # Always extend the number of batches to use all available data
        # The last batch may need padding
        self.total_batches = int(np.ceil(self.mixdb.total_feature_frames(self.mixids) / frames_per_batch))

        # Compute mixid, offset, and extra for dataset
        # offsets and extras are needed because mixtures are not guaranteed to fall on batch boundaries.
        # When fetching a new index that starts in the middle of a sequence of mixtures, the
        # previous feature frame offset must be maintained in order to preserve the correct
        # data sequence. And the extra must be maintained in order to preserve the correct data length.
        cumulative_frames = 0
        start_mixture_index = 0
        offset = 0
        self.batch_params = []
        self.file_indices = []
        total_frames = 0
        for idx, mixid in enumerate(self.mixids):
            current_frames = self.mixdb.mixture(mixid).samples // self.mixdb.feature_step_samples
            # Frame range of this mixture within the concatenated data
            # (keras_predict uses these slices to split predictions per mixture)
            self.file_indices.append(slice(total_frames, total_frames + current_frames))
            total_frames += current_frames
            cumulative_frames += current_frames
            # A long mixture may fill several batches, hence the loop
            while cumulative_frames >= frames_per_batch:
                extra = cumulative_frames - frames_per_batch
                mixids = self.mixids[start_mixture_index:idx + 1]
                self.batch_params.append(self.BatchParams(mixids=mixids, offset=offset, extra=extra, padding=0))
                if extra == 0:
                    start_mixture_index = idx + 1
                    offset = 0
                else:
                    start_mixture_index = idx
                    offset = current_frames - extra
                cumulative_frames = extra

        # If needed, add final batch with padding
        needed_frames = self.total_batches * frames_per_batch
        padding = needed_frames - total_frames
        if padding != 0:
            mixids = self.mixids[start_mixture_index:]
            self.batch_params.append(self.BatchParams(mixids=mixids, offset=offset, extra=0, padding=padding))
|
167
|
-
|
168
|
-
|
169
|
-
# Alias: KerasFromH5 refers to the same class as KerasFromMixtureDatabase.
KerasFromH5 = KerasFromMixtureDatabase
|
sonusai/data_generator/torch_from_mixdb.py
DELETED
@@ -1,122 +0,0 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
|
3
|
-
import numpy as np
|
4
|
-
from torch.utils.data import DataLoader
|
5
|
-
from torch.utils.data import Dataset
|
6
|
-
from torch.utils.data import Sampler
|
7
|
-
|
8
|
-
from sonusai.mixture import GeneralizedIDs
|
9
|
-
from sonusai.mixture import MixtureDatabase
|
10
|
-
|
11
|
-
|
12
|
-
class MixtureDatabaseDataset(Dataset):
    """Generates a PyTorch dataset from a SonusAI mixture database

    Each item is the feature and truth data of one mixture, optionally cut or
    repeated to a fixed number of frames.
    """

    def __init__(self,
                 mixdb: MixtureDatabase,
                 mixids: GeneralizedIDs,
                 cut_len: int,
                 flatten: bool,
                 add1ch: bool,
                 random_cut: bool = True):
        """Initialization

        :param mixdb: Mixture database to read from.
        :param mixids: Mixture IDs to use (converted with ``mixdb.mixids_to_list``).
        :param cut_len: Frames per item; a value of 0 or less keeps the whole mixture.
        :param flatten: Passed to ``reshape_inputs``.
        :param add1ch: Passed to ``reshape_inputs``.
        :param random_cut: If True, a mixture longer than ``cut_len`` is cut at a
            random start frame; otherwise it is cut from frame 0.
        """
        self.mixdb = mixdb
        self.mixids = self.mixdb.mixids_to_list(mixids)
        self.cut_len = cut_len
        self.flatten = flatten
        self.add1ch = add1ch
        self.random_cut = random_cut

    def __len__(self) -> int:
        return len(self.mixids)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray, int]:
        """Get data from one mixture

        :param idx: Index into this dataset's mixture ID list.
        :return: ``(feature, truth, idx)``. When ``cut_len > 0``, a mixture shorter
            than ``cut_len`` is repeated end-to-end and truncated, and a longer one
            is cropped to ``cut_len`` frames.
        :raises ValueError: If ``cut_len > 0`` and the mixture has no frames.
        """
        import random

        from sonusai.utils import reshape_inputs

        feature, truth = self.mixdb.mixture_ft(self.mixids[idx])
        feature, truth = reshape_inputs(feature=feature,
                                        truth=truth,
                                        batch_size=1,
                                        timesteps=0,
                                        flatten=self.flatten,
                                        add1ch=self.add1ch)

        length = feature.shape[0]

        if self.cut_len > 0:
            if length == 0:
                # Without this guard the tiling below divides by zero
                raise ValueError(f'Mixture {self.mixids[idx]} has no frames; cannot cut to {self.cut_len} frames')
            if length < self.cut_len:
                # Whole copies of the mixture followed by a partial copy to fill cut_len
                repeats, remainder = divmod(self.cut_len, length)
                feature = np.vstack([feature] * repeats + [feature[:remainder]])
                truth = np.vstack([truth] * repeats + [truth[:remainder]])
            else:
                start = random.randint(0, length - self.cut_len) if self.random_cut else 0
                feature = feature[start:start + self.cut_len]
                truth = truth[start:start + self.cut_len]

        return feature, truth, idx
|
72
|
-
|
73
|
-
|
74
|
-
class AawareDataLoader(DataLoader):
    """PyTorch ``DataLoader`` that additionally carries a ``cut_len`` attribute.

    ``TorchFromMixtureDatabase`` sets ``cut_len`` to the value it passed to the
    dataset it wraps.
    """

    # Backing storage for the ``cut_len`` property; None until assigned.
    _cut_len: Optional[int] = None

    def _get_cut_len(self) -> Optional[int]:
        return self._cut_len

    def _set_cut_len(self, value: int) -> None:
        self._cut_len = value

    cut_len = property(_get_cut_len, _set_cut_len, doc='Frame cut length associated with this loader.')
|
84
|
-
|
85
|
-
|
86
|
-
def TorchFromMixtureDatabase(mixdb: MixtureDatabase,
                             mixids: GeneralizedIDs,
                             batch_size: int,
                             flatten: bool,
                             add1ch: bool,
                             num_workers: int = 0,
                             cut_len: int = 0,
                             drop_last: bool = False,
                             shuffle: bool = False,
                             random_cut: bool = True,
                             sampler: Optional[type[Sampler]] = None,
                             pin_memory: bool = False) -> AawareDataLoader:
    """Generates a PyTorch dataloader from a SonusAI mixture database

    :param mixdb: Mixture database to read from.
    :param mixids: Mixture IDs to include.
    :param batch_size: DataLoader batch size.
    :param flatten: Passed through to ``MixtureDatabaseDataset``.
    :param add1ch: Passed through to ``MixtureDatabaseDataset``.
    :param num_workers: DataLoader worker count.
    :param cut_len: Frames per item (see ``MixtureDatabaseDataset``); also stored on the loader.
    :param drop_last: DataLoader ``drop_last``.
    :param shuffle: DataLoader ``shuffle``.
    :param random_cut: Passed through to ``MixtureDatabaseDataset``.
    :param sampler: Optional sampler class; it is instantiated with the dataset.
    :param pin_memory: DataLoader ``pin_memory``.
    :return: An ``AawareDataLoader`` over a ``MixtureDatabaseDataset``.
    """
    dataset = MixtureDatabaseDataset(mixdb=mixdb,
                                     mixids=mixids,
                                     cut_len=cut_len,
                                     flatten=flatten,
                                     add1ch=add1ch,
                                     random_cut=random_cut)

    # ``sampler`` is a class, not an instance: build it over the dataset
    my_sampler = None if sampler is None else sampler(dataset)

    loader = AawareDataLoader(dataset=dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=my_sampler,
                              num_workers=num_workers,
                              pin_memory=pin_memory,
                              drop_last=drop_last)
    loader.cut_len = cut_len
    return loader
|
sonusai/keras_onnx.py
DELETED
@@ -1,86 +0,0 @@
|
|
1
|
-
"""sonusai keras_onnx
|
2
|
-
|
3
|
-
usage: keras_onnx [-hvr] (-m MODEL) (-w WEIGHTS) [-b BATCH] [-t TSTEPS] [-o OUTPUT]
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-m MODEL, --model MODEL Python model file.
|
9
|
-
-w WEIGHTS, --weights WEIGHTS Keras model weights file.
|
10
|
-
-b BATCH, --batch BATCH Batch size.
|
11
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps.
|
12
|
-
-o OUTPUT, --output OUTPUT Output directory.
|
13
|
-
|
14
|
-
Convert a trained Keras model to ONNX.
|
15
|
-
|
16
|
-
Inputs:
|
17
|
-
MODEL A SonusAI Python model file with build and/or hypermodel functions.
|
18
|
-
WEIGHTS A Keras model weights file (or model file with weights).
|
19
|
-
|
20
|
-
Outputs:
|
21
|
-
OUTPUT/ A directory containing:
|
22
|
-
<MODEL>.onnx Model file with batch_size and timesteps equal to provided parameters
|
23
|
-
<MODEL>-b1.onnx Model file with batch_size=1 and if the timesteps dimension exists it
|
24
|
-
is set to 1 (useful for real-time inference applications)
|
25
|
-
keras_onnx.log
|
26
|
-
|
27
|
-
Results are written into subdirectory <MODEL>-<TIMESTAMP> unless OUTPUT is specified.
|
28
|
-
|
29
|
-
"""
|
30
|
-
from sonusai import logger
|
31
|
-
|
32
|
-
|
33
|
-
def main() -> None:
    """Command-line entry point for ``keras_onnx``.

    Parses the docopt usage in the module docstring, prepares the output
    directory and log file, then delegates the conversion to
    ``sonusai.utils.keras_onnx``.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    model_name = args['--model']
    weight_name = args['--weights']
    batch_arg = args['--batch']
    tsteps_arg = args['--tsteps']
    output_dir = args['--output']

    from os import makedirs
    from os.path import basename
    from os.path import join
    from os.path import splitext

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.utils import create_ts_name
    from sonusai.utils import keras_onnx

    # Stem of the model file name, used for the default output directory
    model_stem, _ = splitext(basename(model_name))

    # docopt yields option values as strings, or None when the option is omitted
    batch_size = int(batch_arg) if batch_arg is not None else None
    timesteps = int(tsteps_arg) if tsteps_arg is not None else None

    if output_dir is None:
        output_dir = create_ts_name(model_stem)
    makedirs(output_dir, exist_ok=True)

    # Logging goes to both the console and a log file in the output directory
    create_file_handler(join(output_dir, 'keras_onnx.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_onnx')

    keras_onnx(model_name, weight_name, timesteps, batch_size, output_dir)
|
79
|
-
|
80
|
-
|
81
|
-
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # Raise SystemExit directly: the exit() helper is added by the site
        # module for interactive use and is missing when Python runs with -S.
        raise SystemExit()
|
sonusai/keras_predict.py
DELETED
@@ -1,231 +0,0 @@
|
|
1
|
-
"""sonusai keras_predict
|
2
|
-
|
3
|
-
usage: keras_predict [-hvr] [-i MIXID] (-m MODEL) (-w KMODEL) [-b BATCH] [-t TSTEPS] INPUT ...
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
|
9
|
-
-m MODEL, --model MODEL Python model file.
|
10
|
-
-w KMODEL, --weights KMODEL Keras model weights file.
|
11
|
-
-b BATCH, --batch BATCH Batch size.
|
12
|
-
-t TSTEPS, --tsteps TSTEPS Timesteps.
|
13
|
-
-r, --reset Reset model between each file.
|
14
|
-
|
15
|
-
Run prediction on a trained Keras model defined by a SonusAI Keras Python model file using SonusAI genft or WAV data.
|
16
|
-
|
17
|
-
Inputs:
|
18
|
-
MODEL A SonusAI Python model file with build and/or hypermodel functions.
|
19
|
-
KMODEL A Keras model weights file (or model file with weights).
|
20
|
-
INPUT The input data must be one of the following:
|
21
|
-
* Single WAV file or glob of WAV files
|
22
|
-
Using the given model, generate feature data and run prediction. A model file must be
|
23
|
-
provided. The MIXID is ignored.
|
24
|
-
|
25
|
-
* directory
|
26
|
-
Using the given SonusAI mixture database directory, generate feature and truth data if not found.
|
27
|
-
Run prediction. The MIXID is required.
|
28
|
-
|
29
|
-
Outputs the following to kpredict-<TIMESTAMP> directory:
|
30
|
-
<id>.h5
|
31
|
-
dataset: predict
|
32
|
-
keras_predict.log
|
33
|
-
|
34
|
-
"""
|
35
|
-
from typing import Any
|
36
|
-
|
37
|
-
from sonusai import logger
|
38
|
-
from sonusai.mixture import Feature
|
39
|
-
from sonusai.mixture import Predict
|
40
|
-
|
41
|
-
|
42
|
-
def main() -> None:
    """Command-line entry point for ``keras_predict``.

    Builds the Keras model described by MODEL, loads the KMODEL weights, and
    runs prediction over either the mixtures of a SonusAI mixture database
    directory or a list of WAV files. Results are written into a new
    ``kpredict-<TIMESTAMP>`` directory.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    mixids = args['--mixid']
    model_name = args['--model']
    weights_name = args['--weights']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    reset = args['--reset']
    input_name = args['INPUT']

    from os import makedirs
    from os.path import basename
    from os.path import isdir
    from os.path import isfile
    from os.path import join
    from os.path import splitext

    import h5py
    import keras_tuner as kt
    import tensorflow as tf
    from keras import backend as kb

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.data_generator import KerasFromH5
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import read_audio
    from sonusai.utils import create_ts_name
    from sonusai.utils import get_frames_per_batch
    from sonusai.utils import import_and_check_keras_model
    from sonusai.utils import reshape_outputs

    # docopt yields option values as strings, or None when the option is omitted
    if batch_size is not None:
        batch_size = int(batch_size)

    if timesteps is not None:
        timesteps = int(timesteps)

    output_dir = create_ts_name('kpredict')
    makedirs(output_dir, exist_ok=True)

    # Setup logging file
    create_file_handler(join(output_dir, 'keras_predict.log'))
    update_console_handler(verbose)
    initial_log_messages('keras_predict')

    logger.info(f'tensorflow {tf.__version__}')
    logger.info(f'keras {tf.keras.__version__}')
    logger.info('')

    hypermodel = import_and_check_keras_model(model_name=model_name,
                                              weights_name=weights_name,
                                              timesteps=timesteps,
                                              batch_size=batch_size)
    built_model = hypermodel.build_model(kt.HyperParameters())

    frames_per_batch = get_frames_per_batch(hypermodel.batch_size, hypermodel.timesteps)

    kb.clear_session()
    logger.info('')
    built_model.summary(print_fn=logger.info)
    logger.info('')
    logger.info(f'feature {hypermodel.feature}')
    logger.info(f'num_classes {hypermodel.num_classes}')
    logger.info(f'batch_size {hypermodel.batch_size}')
    logger.info(f'timesteps {hypermodel.timesteps}')
    logger.info(f'flatten {hypermodel.flatten}')
    logger.info(f'add1ch {hypermodel.add1ch}')
    logger.info(f'truth_mutex {hypermodel.truth_mutex}')
    logger.info(f'input_shape {hypermodel.input_shape}')
    logger.info('')

    logger.info(f'Loading weights from {weights_name}')
    built_model.load_weights(weights_name)

    logger.info('')
    if len(input_name) == 1 and isdir(input_name[0]):
        input_name = input_name[0]
        logger.info(f'Load mixture database from {input_name}')
        mixdb = MixtureDatabase(input_name)

        if mixdb.feature != hypermodel.feature:
            # logger.error rather than logger.exception: no exception is being handled here
            logger.error(f'Feature in mixture database ({mixdb.feature}) '
                         f'does not match feature in model ({hypermodel.feature})')
            raise SystemExit(1)

        mixids = mixdb.mixids_to_list(mixids)
        if reset:
            # reset mode cycles through each file one at a time
            for mixid in mixids:
                feature, _ = mixdb.mixture_ft(mixid)

                feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                                    built_model=built_model,
                                                    feature=feature,
                                                    frames_per_batch=frames_per_batch)

                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    if 'predict' in f:
                        del f['predict']
                    f.create_dataset(name='predict', data=predict)
        else:
            # Run all data at once using a data generator
            generator = KerasFromH5(mixdb=mixdb,
                                    mixids=mixids,
                                    batch_size=hypermodel.batch_size,
                                    timesteps=hypermodel.timesteps,
                                    flatten=hypermodel.flatten,
                                    add1ch=hypermodel.add1ch)

            predict = built_model.predict(generator, batch_size=hypermodel.batch_size, verbose=1)
            predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)

            # Write data to separate files, splitting predictions by each mixture's frame range
            for idx, mixid in enumerate(mixids):
                output_name = join(output_dir, mixdb.mixtures[mixid].name)
                with h5py.File(output_name, 'a') as f:
                    if 'predict' in f:
                        del f['predict']
                    f.create_dataset('predict', data=predict[generator.file_indices[idx]])

        logger.info(f'Saved results to {output_dir}')
        return

    if not all(isfile(file) and splitext(file)[1] == '.wav' for file in input_name):
        logger.error(f'Do not know how to process input from {input_name}')
        raise SystemExit(1)

    logger.info(f'Run prediction on {len(input_name):,} WAV files')
    for file in input_name:
        # Convert WAV to feature data
        audio = read_audio(file)
        feature = get_feature_from_audio(audio=audio, feature_mode=hypermodel.feature)

        feature, predict = _pad_and_predict(hypermodel=hypermodel,
                                            built_model=built_model,
                                            feature=feature,
                                            frames_per_batch=frames_per_batch)

        output_name = join(output_dir, splitext(basename(file))[0] + '.h5')
        with h5py.File(output_name, 'a') as f:
            if 'feature' in f:
                del f['feature']
            f.create_dataset(name='feature', data=feature)

            if 'predict' in f:
                del f['predict']
            f.create_dataset(name='predict', data=predict)

    logger.info(f'Saved results to {output_dir}')
|
201
|
-
|
202
|
-
|
203
|
-
def _pad_and_predict(hypermodel: Any,
                     built_model: Any,
                     feature: Feature,
                     frames_per_batch: int) -> tuple[Feature, Predict]:
    """Zero-pad feature frames to a whole number of batches and run prediction.

    :param hypermodel: Model wrapper providing ``batch_size``, ``timesteps``,
        ``flatten`` and ``add1ch``.
    :param built_model: Keras model whose ``predict`` is called.
    :param feature: Feature data padded along axis 0 (frames); the ``np.pad``
        call below assumes a 3-D array.
    :param frames_per_batch: Number of frames the model consumes per batch.
    :return: ``(feature, predict)`` where ``feature`` is the padded input after
        ``reshape_inputs`` and ``predict`` is trimmed to the original frame count.
    """
    import numpy as np

    from sonusai.utils import reshape_inputs
    from sonusai.utils import reshape_outputs

    frames = feature.shape[0]
    # Frames needed to reach the next multiple of frames_per_batch (0 when already aligned)
    padding = (-frames) % frames_per_batch
    if frames == 0:
        # Keep one all-zero batch so the model always receives input
        padding = frames_per_batch
    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
    feature, _ = reshape_inputs(feature=feature,
                                batch_size=hypermodel.batch_size,
                                timesteps=hypermodel.timesteps,
                                flatten=hypermodel.flatten,
                                add1ch=hypermodel.add1ch)
    predict = built_model.predict(feature, batch_size=hypermodel.batch_size, verbose=1)
    predict, _ = reshape_outputs(predict=predict, timesteps=hypermodel.timesteps)
    # Discard predictions for the padding frames
    predict = predict[:frames, :]
    return feature, predict
|
224
|
-
|
225
|
-
|
226
|
-
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # Raise SystemExit directly: the exit() helper is added by the site
        # module for interactive use and is missing when Python runs with -S.
        raise SystemExit()
|