sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,519 @@
|
|
1
|
+
"""sonusai plot
|
2
|
+
|
3
|
+
usage: plot [-hve] [-i MIXID] [-m MODEL] [-l CSV] [-o OUTPUT] INPUT
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-i MIXID, --mixid MIXID Mixture to plot if input is a mixture database.
|
9
|
+
-m MODEL, --model MODEL Trained model ONNX file.
|
10
|
+
-l CSV, --labels CSV Optional CSV file of class labels (from SonusAI gentcst).
|
11
|
+
-o OUTPUT, --output OUTPUT Optional output HDF5 file for prediction.
|
12
|
+
-e, --energy Use energy plots.
|
13
|
+
|
14
|
+
Plot SonusAI audio, feature, truth, and prediction data. INPUT must be one of the following:
|
15
|
+
|
16
|
+
* WAV
|
17
|
+
Using the given model, generate feature data and run prediction. A model file must be
|
18
|
+
provided. The MIXID is ignored. If --energy is specified, plot predict data as energy.
|
19
|
+
|
20
|
+
* directory
|
21
|
+
Using the given SonusAI mixture database directory, generate feature and truth data if not found.
|
22
|
+
Run prediction if a model is given. The MIXID is required. (--energy is ignored.)
|
23
|
+
|
24
|
+
Prediction data will be written to OUTPUT if a model file is given and OUTPUT is specified.
|
25
|
+
|
26
|
+
There will be one plot per active truth index. In addition, the top 5 prediction classes are determined and
|
27
|
+
plotted if needed (i.e., if they were not already included in the truth plots). For plots generated using a
|
28
|
+
mixture database, then the target will also be displayed. If mixup is active, then each target involved will
|
29
|
+
be added to the corresponding truth plot.
|
30
|
+
|
31
|
+
Inputs:
|
32
|
+
MODEL A SonusAI trained ONNX model file. If a model file is given, prediction data will be
|
33
|
+
generated.
|
34
|
+
INPUT A WAV file, or
|
35
|
+
a directory containing a SonusAI mixture database
|
36
|
+
|
37
|
+
Outputs:
|
38
|
+
{INPUT}-plot.pdf or {INPUT}-mix{MIXID}-plot.pdf
|
39
|
+
plot.log
|
40
|
+
OUTPUT (if MODEL and OUTPUT are both specified)
|
41
|
+
|
42
|
+
"""
|
43
|
+
|
44
|
+
import signal
|
45
|
+
|
46
|
+
import numpy as np
|
47
|
+
from matplotlib import pyplot as plt
|
48
|
+
from sonusai.datatypes import AudioT
|
49
|
+
from sonusai.datatypes import Feature
|
50
|
+
from sonusai.datatypes import Predict
|
51
|
+
from sonusai.datatypes import Truth
|
52
|
+
|
53
|
+
|
54
|
+
def signal_handler(_sig, _frame):
    """SIGINT (Ctrl-C) handler: log that the run was canceled and exit with status 1.

    The signal number and current stack frame passed in by the ``signal`` module
    are accepted to satisfy the handler signature but are not used.
    """
    from sonusai import logger

    logger.info("Canceled due to keyboard interrupt")
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)


# Registered at import time so an interrupt anywhere in this script exits cleanly.
signal.signal(signal.SIGINT, signal_handler)
|
64
|
+
|
65
|
+
|
66
|
+
def spec_plot(
    mixture: AudioT,
    feature: Feature,
    predict: Predict | None = None,
    target: AudioT | None = None,
    labels: list[str] | None = None,
    title: str = "",
) -> plt.Figure:
    """Plot the mixture waveform, the feature image and, optionally, class scores.

    Args:
        mixture: Mixture waveform samples, plotted against time in seconds.
        feature: 2-D feature data. It is transposed before display, so the first
            axis runs along the x-axis (presumably frames -- confirm with callers).
        predict: Optional 2-D prediction data indexed as [frame, class]. When given,
            two extra panels show the (up to) 10 classes with the highest mean score.
        target: Optional target waveform drawn on top of the mixture.
        labels: Class label names indexed by class; required when ``predict`` is given.
        title: Figure title.

    Returns:
        The created matplotlib figure. The caller owns it and should close it.

    Raises:
        ValueError: If ``predict`` is given without ``labels``.
    """
    from sonusai.constants import SAMPLE_RATE

    # Validate before plt.subplots(): raising afterwards would leave an orphaned,
    # never-closed figure registered with pyplot.
    if predict is not None and labels is None:
        raise ValueError("Provided predict without labels")

    num_plots = 4 if predict is not None else 2
    fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the waveform
    x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
    ax[0].plot(x_axis, mixture, label="Mixture")
    ax[0].set_xlim(x_axis[0], x_axis[-1])
    ax[0].set_ylim([-1.025, 1.025])
    if target is not None:
        # Plot target time-domain waveform on top of mixture
        color = "tab:blue"
        ax[0].plot(x_axis, target, color=color, label="Target")
        ax[0].set_ylabel("magnitude", color=color)
    ax[0].set_title("Waveform")

    # Plot the spectrogram
    ax[1].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
    ax[1].set_title("Feature")

    if predict is not None:
        # Rank classes by mean score over all frames, highest first, and keep at most 10.
        mean_predict = np.mean(predict, axis=0)
        num_classes = predict.shape[-1]
        top_n = min(10, num_classes)
        top_class_indices = np.argsort(mean_predict)[::-1][:top_n]

        # Plot and label the model output scores for the top-scoring classes.
        ax[2].imshow(
            np.transpose(predict[:, top_class_indices]),
            aspect="auto",
            interpolation="nearest",
            cmap="gray_r",
        )
        y_ticks = range(top_n)
        ax[2].set_yticks(y_ticks, [labels[top_class_indices[y]] for y in y_ticks])
        # Reversed limits put row 0 (the highest-ranked class) at the top.
        ax[2].set_ylim(-0.5 + np.array([top_n, 0]))
        ax[2].set_title("Class Scores")

        # Plot the probabilities
        ax[3].plot(predict[:, top_class_indices])
        ax[3].legend(np.array(labels)[top_class_indices], loc="best")
        ax[3].set_title("Class Probabilities")

    fig.suptitle(title)

    return fig
|
123
|
+
|
124
|
+
|
125
|
+
def spec_energy_plot(
    mixture: AudioT, feature: Feature, truth_f: Truth | None = None, predict: Predict | None = None
) -> plt.Figure:
    """Plot the mixture waveform above feature, truth and prediction images.

    Panels are stacked top to bottom as: waveform, feature, truth (if given),
    predict (if given). Each 2-D input is transposed before it is displayed.

    Returns:
        The created matplotlib figure. The caller owns it and should close it.
    """
    from sonusai.constants import SAMPLE_RATE

    # (title, data) for every image panel shown below the waveform, in display order.
    image_panels = [("Feature", feature)]
    if truth_f is not None:
        image_panels.append(("Truth", truth_f))
    if predict is not None:
        image_panels.append(("Predict", predict))

    fig, axes = plt.subplots(1 + len(image_panels), 1, constrained_layout=True, figsize=(11, 8.5))

    # Top panel: time-domain waveform with the x-axis in seconds.
    wave_ax = axes[0]
    time_s = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
    wave_ax.plot(time_s, mixture, label="Mixture")
    wave_ax.set_xlim(time_s[0], time_s[-1])
    wave_ax.set_ylim([-1.025, 1.025])
    wave_ax.set_title("Waveform")

    for panel_ax, (panel_title, data) in zip(axes[1:], image_panels):
        panel_ax.imshow(np.transpose(data), aspect="auto", interpolation="nearest", origin="lower")
        panel_ax.set_title(panel_title)

    return fig
|
172
|
+
|
173
|
+
|
174
|
+
def class_plot(
    mixture: AudioT,
    target: AudioT | None = None,
    truth_f: Truth | None = None,
    predict: Predict | None = None,
    label: str = "",
) -> plt.Figure:
    """Plot mixture waveform with optional prediction and/or truth together in a single plot

    Waveforms are drawn against the left y-axis. Per-frame truth and prediction
    values are expanded to sample resolution and drawn against a second y-axis.

    Inputs:
        mixture     required, 1-D numpy array [samples]
        target      optional, 1-D numpy array [samples]
        truth_f     optional, 1-D numpy array [frames] (a single class)
        predict     optional, 1-D numpy array [frames] (a single class)
        label       optional, label name to use as the figure title

    If truth_f and predict have different lengths, both are truncated to the
    shorter one so they line up on the same time axis. The mixture (and target)
    are truncated to the largest multiple of the frame count.

    Returns:
        The created matplotlib figure. The caller owns it and should close it.

    Raises:
        SonusAIError: If an input is not 1-D, or if there are no frames to plot or
            the mixture has fewer samples than frames.
    """
    from sonusai import SonusAIError
    from sonusai.constants import SAMPLE_RATE

    if mixture.ndim != 1:
        raise SonusAIError("Too many dimensions in mixture")

    if target is not None and target.ndim != 1:
        raise SonusAIError("Too many dimensions in target")

    if truth_f is not None and truth_f.ndim != 1:
        raise SonusAIError("Too many dimensions in truth_f")

    if predict is not None and predict.ndim != 1:
        raise SonusAIError("Too many dimensions in predict")

    # Frames to plot: the shortest provided frame series, or 1 when there is no
    # truth or predict data.
    frame_counts = [len(x) for x in (truth_f, predict) if x is not None]
    frames = min(frame_counts) if frame_counts else 1

    if frames == 0 or len(mixture) < frames:
        raise SonusAIError(f"Cannot plot {len(mixture)} mixture samples against {frames} truth/predict frames")

    # Truncate both frame series to `frames`. `samples` below is a multiple of `frames`,
    # so a longer series could not be expanded to exactly `samples` values.
    if truth_f is not None:
        truth_f = truth_f[:frames]
    if predict is not None:
        predict = predict[:frames]

    samples = (len(mixture) // frames) * frames

    # x-axis in sec
    x_axis = np.arange(samples, dtype=np.float32) / SAMPLE_RATE

    fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(11, 8.5))

    # Plot the time-domain waveforms then truth/prediction on second axis
    ax.plot(x_axis, mixture[0:samples], color="mistyrose", label="Mixture")
    color = "red"
    ax.set_xlim(x_axis[0], x_axis[-1])
    ax.set_ylim((-1.025, 1.025))
    ax.set_ylabel("Amplitude", color=color)
    ax.tick_params(axis="y", labelcolor=color)

    # Plot target time-domain waveform
    if target is not None:
        ax.plot(x_axis, target[0:samples], color="blue", label="Target")

    # instantiate 2nd y-axis that shares the same x-axis
    if truth_f is not None or predict is not None:
        if truth_f is None:
            y_label = "Predict"
        elif predict is None:
            y_label = "Truth"
        else:
            y_label = "Truth/Predict"

        ax2 = ax.twinx()

        color = "black"
        ax2.set_xlim(x_axis[0], x_axis[-1])
        ax2.set_ylim((-0.025, 1.025))
        ax2.set_ylabel(y_label, color=color)
        ax2.tick_params(axis="y", labelcolor=color)

        if truth_f is not None:
            ax2.plot(
                x_axis,
                expand_frames_to_samples(truth_f, samples),
                color="green",
                label="Truth",
            )

        if predict is not None:
            ax2.plot(
                x_axis,
                expand_frames_to_samples(predict, samples),
                color="brown",
                label="Predict",
            )

    # set only on last/bottom plot
    ax.set_xlabel("time (s)")

    fig.suptitle(label)

    return fig
|
280
|
+
|
281
|
+
|
282
|
+
def expand_frames_to_samples(x: np.ndarray, samples: int) -> np.ndarray:
    """Upsample per-frame values to per-sample values by repeating each frame.

    Every element of ``x`` is repeated ``samples // len(x)`` times and the result
    is reshaped to length ``samples``. NumPy raises ValueError when ``samples`` is
    not an exact multiple of ``len(x)``.
    """
    repeats_per_frame = samples // len(x)
    upsampled = np.repeat(x, repeats_per_frame)
    return upsampled.reshape(samples)
|
285
|
+
|
286
|
+
|
287
|
+
def main() -> None:
    """Command-line entry point; see the module docstring for usage, inputs and outputs."""
    from docopt import docopt

    import sonusai
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    from dataclasses import asdict
    from os.path import basename
    from os.path import exists
    from os.path import isdir
    from os.path import splitext

    import h5py
    from matplotlib.backends.backend_pdf import PdfPages
    from pyaaware import FeatureGenerator
    from pyaaware import Predict

    from sonusai import SonusAIError
    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import FeatureGeneratorConfig
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import get_feature_from_audio
    from sonusai.mixture import get_truth_indices_for_mixid
    from sonusai.mixture.audio import read_audio
    from sonusai.utils.get_label_names import get_label_names
    from sonusai.utils.print_mixture_details import print_mixture_details

    verbose = args["--verbose"]
    model_name = args["--model"]
    output_name = args["--output"]
    labels_name = args["--labels"]
    mixid = args["--mixid"]
    energy = args["--energy"]
    input_name = args["INPUT"]

    if mixid is not None:
        mixid = int(mixid)

    create_file_handler("plot.log")
    update_console_handler(verbose)
    initial_log_messages("plot")

    if not exists(input_name):
        raise SonusAIError(f"{input_name} does not exist")

    logger.info("")
    logger.info(f"Input: {input_name}")
    if model_name is not None:
        logger.info(f"Model: {model_name}")
    if output_name is not None:
        logger.info(f"Output: {output_name}")
    logger.info("")

    ext = splitext(input_name)[1]

    model = None
    target_audio = None
    truth_f = None
    t_indices = []

    if model_name is not None:
        model = Predict(model_name)

    if ext == ".wav":
        if model is None:
            raise SonusAIError("Must specify MODEL when input is WAV")

        mixture_audio = read_audio(input_name)
        feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
        fg_config = FeatureGeneratorConfig(
            feature_mode=model.feature,
            num_classes=model.output_shape[-1],
            truth_mutex=False,
        )
        # The feature generator is only needed for its step size (used to build the spectrogram).
        fg = FeatureGenerator(**asdict(fg_config))
        fg_step = fg.step
        mixdb = None
        logger.debug(f"Audio samples {len(mixture_audio)}")
        logger.debug(f"Feature shape {feature.shape}")

    elif isdir(input_name):
        if mixid is None:
            raise SonusAIError("Must specify mixid when input is mixture database")

        mixdb = MixtureDatabase(input_name)
        fg_step = mixdb.fg_step

        print_mixture_details(mixdb=mixdb, mixid=mixid, desc_len=24, print_fn=logger.info)

        logger.info(f"Generating data for mixture {mixid}")
        mixture_audio = mixdb.mixture_mixture(mixid)
        target_audio = mixdb.mixture_target(mixid)
        feature, truth_f = mixdb.mixture_ft(mixid)
        # Truth indices from the database are converted to 0-based class indices.
        t_indices = [x - 1 for x in get_truth_indices_for_mixid(mixdb=mixdb, mixid=mixid)]

        # For a mixture database, energy plots are chosen by the truth functions in use,
        # overriding --energy (as documented in the usage text).
        target_files = [mixdb.target_file(target.file_id) for target in mixdb.mixtures[mixid].targets]
        truth_functions = {sub2.function for sub1 in target_files for sub2 in sub1.truth_configs}
        energy = "energy_f" in truth_functions or "snr_f" in truth_functions

        logger.debug(f"Audio samples {len(mixture_audio)}")
        logger.debug("Targets:")
        mixture = mixdb.mixture(mixid)
        for target in mixture.targets:
            target_file = mixdb.target_file(target.file_id)
            name = target_file.name
            duration = target_file.duration
            augmentation = target.augmentation
            logger.debug(f" Name {name}")
            logger.debug(f" Duration {duration}")
            logger.debug(f" Augmentation {augmentation}")

        logger.debug(f"Feature shape {feature.shape}")
        logger.debug(f"Truth shape {truth_f.shape}")

    else:
        raise SonusAIError(f"Unknown file type for {input_name}")

    predict = None
    labels = None
    indices = []
    if model is not None:
        logger.debug("")
        logger.info(f"Running prediction on mixture {mixid}")
        logger.debug(f"Model feature name {model.feature}")
        logger.debug(f"Model input shape {model.input_shape}")
        logger.debug(f"Model output shape {model.output_shape}")

        if feature.shape[0] < model.input_shape[0]:
            raise SonusAIError(
                f"Mixture {mixid} contains {feature.shape[0]} "
                f"frames of data which is not enough to run prediction; "
                f"at least {model.input_shape[0]} frames are needed for this model.\n"
                f"Consider using a model with a smaller batch size or a mixture with more data."
            )

        predict = model.execute(feature)

        labels = get_label_names(num_labels=predict.shape[1], file=labels_name)

        # Report the highest-scoring classes and their scores.
        p_max = np.max(predict, axis=0)
        p_indices = np.argsort(p_max)[::-1][:5]
        p_max_len = max(len(labels[i]) for i in p_indices)

        logger.info("Top 5 active prediction classes by max:")
        for p_index in p_indices:
            logger.info(f" {labels[p_index]:{p_max_len}s} {p_max[p_index]:.3f}")
        logger.info("")

        indices = list(p_indices)

    # Add truth indices for target (if needed)
    for t_index in t_indices:
        if t_index not in indices:
            indices.append(t_index)

    base_name = basename(splitext(input_name)[0])
    if mixdb is not None:
        title = f"{input_name} Mixture {mixid}"
        pdf_name = f"{base_name}-mix{mixid}-plot.pdf"
    else:
        title = f"{input_name}"
        pdf_name = f"{base_name}-plot.pdf"

    # Original size [frames, stride, feature_parameters]
    # Decimate in the stride dimension
    # Reshape to get frames*decimated_stride, feature_parameters
    if feature.ndim != 3:
        raise SonusAIError("feature does not have 3 dimensions: frames, stride, feature_parameters")
    spectrogram = feature[:, -fg_step:, :]
    spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))

    def save_page(pdf_pages, fig) -> None:
        # Append one figure to the PDF, then close it so figures do not pile up in pyplot.
        pdf_pages.savefig(fig)
        plt.close(fig)

    with PdfPages(pdf_name) as pdf:
        save_page(
            pdf,
            spec_plot(
                mixture=mixture_audio,
                feature=spectrogram,
                predict=predict,
                labels=labels,
                title=title,
            ),
        )

        if energy:
            # The energy view shows every class at once and does not depend on a class
            # index, so it is rendered a single time rather than once per class.
            eps = np.finfo(np.float32).eps
            t_tmp = None if truth_f is None else 10 * np.log10(truth_f + eps)
            p_tmp = None if predict is None else 10 * np.log10(predict + eps)
            save_page(
                pdf,
                spec_energy_plot(
                    mixture=mixture_audio,
                    feature=spectrogram,
                    truth_f=t_tmp,
                    predict=p_tmp,
                ),
            )
        else:
            for index in indices:
                # Each optional input is only indexed when present; for WAV input there is
                # no target or truth data (both are None).
                # NOTE(review): the target is indexed by class index; confirm this matches
                # the structure returned by mixdb.mixture_target().
                save_page(
                    pdf,
                    class_plot(
                        mixture=mixture_audio,
                        target=None if target_audio is None else target_audio[index],
                        truth_f=None if truth_f is None else truth_f[:, index],
                        predict=None if predict is None else predict[:, index],
                        label="" if labels is None else labels[index],
                    ),
                )
    logger.info(f"Wrote {pdf_name}")

    if output_name:
        if predict is None:
            # OUTPUT is only meaningful when a model produced prediction data.
            logger.warning(f"Not writing {output_name}: no prediction data (MODEL was not given)")
        else:
            with h5py.File(output_name, "w") as f:
                f.create_dataset(name="predict", data=predict)
            logger.info(f"Wrote {output_name}")


if __name__ == "__main__":
    main()
|