sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -57,19 +57,21 @@ def signal_handler(_sig, _frame):
|
|
57
57
|
|
58
58
|
from sonusai import logger
|
59
59
|
|
60
|
-
logger.info(
|
60
|
+
logger.info("Canceled due to keyboard interrupt")
|
61
61
|
sys.exit(1)
|
62
62
|
|
63
63
|
|
64
64
|
signal.signal(signal.SIGINT, signal_handler)
|
65
65
|
|
66
66
|
|
67
|
-
def spec_plot(
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
67
|
+
def spec_plot(
|
68
|
+
mixture: AudioT,
|
69
|
+
feature: Feature,
|
70
|
+
predict: Predict | None = None,
|
71
|
+
target: AudioT | None = None,
|
72
|
+
labels: list[str] | None = None,
|
73
|
+
title: str = "",
|
74
|
+
) -> plt.Figure:
|
73
75
|
from sonusai.mixture import SAMPLE_RATE
|
74
76
|
|
75
77
|
num_plots = 4 if predict is not None else 2
|
@@ -77,46 +79,53 @@ def spec_plot(mixture: AudioT,
|
|
77
79
|
|
78
80
|
# Plot the waveform
|
79
81
|
x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
|
80
|
-
ax[0].plot(x_axis, mixture, label=
|
82
|
+
ax[0].plot(x_axis, mixture, label="Mixture")
|
81
83
|
ax[0].set_xlim(x_axis[0], x_axis[-1])
|
82
84
|
ax[0].set_ylim([-1.025, 1.025])
|
83
85
|
if target is not None:
|
84
86
|
# Plot target time-domain waveform on top of mixture
|
85
|
-
color =
|
86
|
-
ax[0].plot(x_axis, target, color=color, label=
|
87
|
-
ax[0].set_ylabel(
|
88
|
-
ax[0].set_title(
|
87
|
+
color = "tab:blue"
|
88
|
+
ax[0].plot(x_axis, target, color=color, label="Target")
|
89
|
+
ax[0].set_ylabel("magnitude", color=color)
|
90
|
+
ax[0].set_title("Waveform")
|
89
91
|
|
90
92
|
# Plot the spectrogram
|
91
|
-
ax[1].imshow(np.transpose(feature), aspect=
|
92
|
-
ax[1].set_title(
|
93
|
+
ax[1].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
|
94
|
+
ax[1].set_title("Feature")
|
93
95
|
|
94
96
|
if predict is not None:
|
97
|
+
if labels is None:
|
98
|
+
raise ValueError("Provided predict without labels")
|
99
|
+
|
95
100
|
# Plot and label the model output scores for the top-scoring classes.
|
96
101
|
mean_predict = np.mean(predict, axis=0)
|
97
102
|
num_classes = predict.shape[-1]
|
98
103
|
top_n = min(10, num_classes)
|
99
104
|
top_class_indices = np.argsort(mean_predict)[::-1][:top_n]
|
100
|
-
ax[2].imshow(
|
105
|
+
ax[2].imshow(
|
106
|
+
np.transpose(predict[:, top_class_indices]),
|
107
|
+
aspect="auto",
|
108
|
+
interpolation="nearest",
|
109
|
+
cmap="gray_r",
|
110
|
+
)
|
101
111
|
y_ticks = range(0, top_n)
|
102
112
|
ax[2].set_yticks(y_ticks, [labels[top_class_indices[x]] for x in y_ticks])
|
103
113
|
ax[2].set_ylim(-0.5 + np.array([top_n, 0]))
|
104
|
-
ax[2].set_title(
|
114
|
+
ax[2].set_title("Class Scores")
|
105
115
|
|
106
116
|
# Plot the probabilities
|
107
117
|
ax[3].plot(predict[:, top_class_indices])
|
108
|
-
ax[3].legend(np.array(labels)[top_class_indices], loc=
|
109
|
-
ax[3].set_title(
|
118
|
+
ax[3].legend(np.array(labels)[top_class_indices], loc="best")
|
119
|
+
ax[3].set_title("Class Probabilities")
|
110
120
|
|
111
121
|
fig.suptitle(title)
|
112
122
|
|
113
123
|
return fig
|
114
124
|
|
115
125
|
|
116
|
-
def spec_energy_plot(
|
117
|
-
|
118
|
-
|
119
|
-
predict: Predict = None) -> plt.Figure:
|
126
|
+
def spec_energy_plot(
|
127
|
+
mixture: AudioT, feature: Feature, truth_f: Truth | None = None, predict: Predict | None = None
|
128
|
+
) -> plt.Figure:
|
120
129
|
from sonusai.mixture import SAMPLE_RATE
|
121
130
|
|
122
131
|
num_plots = 2
|
@@ -130,35 +139,47 @@ def spec_energy_plot(mixture: AudioT,
|
|
130
139
|
# Plot the waveform
|
131
140
|
p = 0
|
132
141
|
x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
|
133
|
-
ax[p].plot(x_axis, mixture, label=
|
142
|
+
ax[p].plot(x_axis, mixture, label="Mixture")
|
134
143
|
ax[p].set_xlim(x_axis[0], x_axis[-1])
|
135
144
|
ax[p].set_ylim([-1.025, 1.025])
|
136
|
-
ax[p].set_title(
|
145
|
+
ax[p].set_title("Waveform")
|
137
146
|
|
138
147
|
# Plot the spectrogram
|
139
148
|
p += 1
|
140
|
-
ax[p].imshow(np.transpose(feature), aspect=
|
141
|
-
ax[p].set_title(
|
149
|
+
ax[p].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
|
150
|
+
ax[p].set_title("Feature")
|
142
151
|
|
143
152
|
if truth_f is not None:
|
144
153
|
p += 1
|
145
|
-
ax[p].imshow(
|
146
|
-
|
154
|
+
ax[p].imshow(
|
155
|
+
np.transpose(truth_f),
|
156
|
+
aspect="auto",
|
157
|
+
interpolation="nearest",
|
158
|
+
origin="lower",
|
159
|
+
)
|
160
|
+
ax[p].set_title("Truth")
|
147
161
|
|
148
162
|
if predict is not None:
|
149
163
|
p += 1
|
150
|
-
ax[p].imshow(
|
151
|
-
|
164
|
+
ax[p].imshow(
|
165
|
+
np.transpose(predict),
|
166
|
+
aspect="auto",
|
167
|
+
interpolation="nearest",
|
168
|
+
origin="lower",
|
169
|
+
)
|
170
|
+
ax[p].set_title("Predict")
|
152
171
|
|
153
172
|
return fig
|
154
173
|
|
155
174
|
|
156
|
-
def class_plot(
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
175
|
+
def class_plot(
|
176
|
+
mixture: AudioT,
|
177
|
+
target: AudioT | None = None,
|
178
|
+
truth_f: Truth | None = None,
|
179
|
+
predict: Predict | None = None,
|
180
|
+
label: str = "",
|
181
|
+
) -> plt.Figure:
|
182
|
+
"""Plot mixture waveform with optional prediction and/or truth together in a single plot
|
162
183
|
|
163
184
|
The target waveform can optionally be provided, and prediction and truth can have multiple classes.
|
164
185
|
|
@@ -174,30 +195,30 @@ def class_plot(mixture: AudioT,
|
|
174
195
|
from sonusai.mixture import SAMPLE_RATE
|
175
196
|
|
176
197
|
if mixture.ndim != 1:
|
177
|
-
raise SonusAIError(
|
198
|
+
raise SonusAIError("Too many dimensions in mixture")
|
178
199
|
|
179
200
|
if target is not None and target.ndim != 1:
|
180
|
-
raise SonusAIError(
|
201
|
+
raise SonusAIError("Too many dimensions in target")
|
181
202
|
|
182
203
|
# Set default to 1 frame when there is no truth or predict data
|
183
204
|
frames = 1
|
184
205
|
if truth_f is not None and predict is not None:
|
185
206
|
if truth_f.ndim != 1:
|
186
|
-
raise SonusAIError(
|
207
|
+
raise SonusAIError("Too many dimensions in truth_f")
|
187
208
|
t_frames = len(truth_f)
|
188
209
|
|
189
210
|
if predict.ndim != 1:
|
190
|
-
raise SonusAIError(
|
211
|
+
raise SonusAIError("Too many dimensions in predict")
|
191
212
|
p_frames = len(predict)
|
192
213
|
|
193
214
|
frames = min(t_frames, p_frames)
|
194
215
|
elif truth_f is not None:
|
195
216
|
if truth_f.ndim != 1:
|
196
|
-
raise SonusAIError(
|
217
|
+
raise SonusAIError("Too many dimensions in truth_f")
|
197
218
|
frames = len(truth_f)
|
198
219
|
elif predict is not None:
|
199
220
|
if predict.ndim != 1:
|
200
|
-
raise SonusAIError(
|
221
|
+
raise SonusAIError("Too many dimensions in predict")
|
201
222
|
frames = len(predict)
|
202
223
|
|
203
224
|
samples = (len(mixture) // frames) * frames
|
@@ -208,41 +229,51 @@ def class_plot(mixture: AudioT,
|
|
208
229
|
fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(11, 8.5))
|
209
230
|
|
210
231
|
# Plot the time-domain waveforms then truth/prediction on second axis
|
211
|
-
ax.plot(x_axis, mixture[0:samples], color=
|
212
|
-
color =
|
232
|
+
ax.plot(x_axis, mixture[0:samples], color="mistyrose", label="Mixture")
|
233
|
+
color = "red"
|
213
234
|
ax.set_xlim(x_axis[0], x_axis[-1])
|
214
|
-
ax.set_ylim(
|
215
|
-
ax.set_ylabel(
|
216
|
-
ax.tick_params(axis=
|
235
|
+
ax.set_ylim((-1.025, 1.025))
|
236
|
+
ax.set_ylabel("Amplitude", color=color)
|
237
|
+
ax.tick_params(axis="y", labelcolor=color)
|
217
238
|
|
218
239
|
# Plot target time-domain waveform
|
219
240
|
if target is not None:
|
220
|
-
ax.plot(x_axis, target[0:samples], color=
|
241
|
+
ax.plot(x_axis, target[0:samples], color="blue", label="Target")
|
221
242
|
|
222
243
|
# instantiate 2nd y-axis that shares the same x-axis
|
223
244
|
if truth_f is not None or predict is not None:
|
224
|
-
y_label =
|
245
|
+
y_label = "Truth/Predict"
|
225
246
|
if truth_f is None:
|
226
|
-
y_label =
|
247
|
+
y_label = "Predict"
|
227
248
|
if predict is None:
|
228
|
-
y_label =
|
249
|
+
y_label = "Truth"
|
229
250
|
|
230
251
|
ax2 = ax.twinx()
|
231
252
|
|
232
|
-
color =
|
253
|
+
color = "black"
|
233
254
|
ax2.set_xlim(x_axis[0], x_axis[-1])
|
234
|
-
ax2.set_ylim(
|
255
|
+
ax2.set_ylim((-0.025, 1.025))
|
235
256
|
ax2.set_ylabel(y_label, color=color)
|
236
|
-
ax2.tick_params(axis=
|
257
|
+
ax2.tick_params(axis="y", labelcolor=color)
|
237
258
|
|
238
259
|
if truth_f is not None:
|
239
|
-
ax2.plot(
|
260
|
+
ax2.plot(
|
261
|
+
x_axis,
|
262
|
+
expand_frames_to_samples(truth_f, samples),
|
263
|
+
color="green",
|
264
|
+
label="Truth",
|
265
|
+
)
|
240
266
|
|
241
267
|
if predict is not None:
|
242
|
-
ax2.plot(
|
268
|
+
ax2.plot(
|
269
|
+
x_axis,
|
270
|
+
expand_frames_to_samples(predict, samples),
|
271
|
+
color="brown",
|
272
|
+
label="Predict",
|
273
|
+
)
|
243
274
|
|
244
275
|
# set only on last/bottom plot
|
245
|
-
ax.set_xlabel(
|
276
|
+
ax.set_xlabel("time (s)")
|
246
277
|
|
247
278
|
fig.suptitle(label)
|
248
279
|
|
@@ -263,7 +294,6 @@ def main() -> None:
|
|
263
294
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
264
295
|
|
265
296
|
from dataclasses import asdict
|
266
|
-
|
267
297
|
from os.path import basename
|
268
298
|
from os.path import exists
|
269
299
|
from os.path import isdir
|
@@ -279,117 +309,121 @@ def main() -> None:
|
|
279
309
|
from sonusai import initial_log_messages
|
280
310
|
from sonusai import logger
|
281
311
|
from sonusai import update_console_handler
|
282
|
-
from sonusai.mixture import MixtureDatabase
|
283
312
|
from sonusai.mixture import FeatureGeneratorConfig
|
313
|
+
from sonusai.mixture import MixtureDatabase
|
284
314
|
from sonusai.mixture import get_feature_from_audio
|
285
315
|
from sonusai.mixture import get_truth_indices_for_mixid
|
286
316
|
from sonusai.mixture import read_audio
|
287
317
|
from sonusai.utils import get_label_names
|
288
318
|
from sonusai.utils import print_mixture_details
|
289
319
|
|
290
|
-
verbose = args[
|
291
|
-
model_name = args[
|
292
|
-
output_name = args[
|
293
|
-
labels_name = args[
|
294
|
-
mixid = args[
|
295
|
-
energy = args[
|
296
|
-
input_name = args[
|
320
|
+
verbose = args["--verbose"]
|
321
|
+
model_name = args["--model"]
|
322
|
+
output_name = args["--output"]
|
323
|
+
labels_name = args["--labels"]
|
324
|
+
mixid = args["--mixid"]
|
325
|
+
energy = args["--energy"]
|
326
|
+
input_name = args["INPUT"]
|
297
327
|
|
298
328
|
if mixid is not None:
|
299
329
|
mixid = int(mixid)
|
300
330
|
|
301
|
-
create_file_handler(
|
331
|
+
create_file_handler("plot.log")
|
302
332
|
update_console_handler(verbose)
|
303
|
-
initial_log_messages(
|
333
|
+
initial_log_messages("plot")
|
304
334
|
|
305
335
|
if not exists(input_name):
|
306
|
-
raise SonusAIError(f
|
336
|
+
raise SonusAIError(f"{input_name} does not exist")
|
307
337
|
|
308
|
-
logger.info(
|
309
|
-
logger.info(f
|
338
|
+
logger.info("")
|
339
|
+
logger.info(f"Input: {input_name}")
|
310
340
|
if model_name is not None:
|
311
|
-
logger.info(f
|
341
|
+
logger.info(f"Model: {model_name}")
|
312
342
|
if output_name is not None:
|
313
|
-
logger.info(f
|
314
|
-
logger.info(
|
343
|
+
logger.info(f"Output: {output_name}")
|
344
|
+
logger.info("")
|
315
345
|
|
316
346
|
ext = splitext(input_name)[1]
|
317
347
|
|
318
348
|
model = None
|
319
349
|
target_audio = None
|
320
350
|
truth_f = None
|
321
|
-
t_indices =
|
351
|
+
t_indices = []
|
322
352
|
|
323
353
|
if model_name is not None:
|
324
354
|
model = Predict(model_name)
|
325
355
|
|
326
|
-
if ext ==
|
327
|
-
if
|
328
|
-
raise SonusAIError(
|
356
|
+
if ext == ".wav":
|
357
|
+
if model is None:
|
358
|
+
raise SonusAIError("Must specify MODEL when input is WAV")
|
329
359
|
|
330
360
|
mixture_audio = read_audio(input_name)
|
331
361
|
feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
|
332
|
-
fg_config = FeatureGeneratorConfig(
|
333
|
-
|
334
|
-
|
362
|
+
fg_config = FeatureGeneratorConfig(
|
363
|
+
feature_mode=model.feature,
|
364
|
+
num_classes=model.output_shape[-1],
|
365
|
+
truth_mutex=False,
|
366
|
+
)
|
335
367
|
fg = FeatureGenerator(**asdict(fg_config))
|
336
368
|
fg_step = fg.step
|
337
369
|
mixdb = None
|
338
|
-
logger.debug(f
|
339
|
-
logger.debug(f
|
370
|
+
logger.debug(f"Audio samples {len(mixture_audio)}")
|
371
|
+
logger.debug(f"Feature shape {feature.shape}")
|
340
372
|
|
341
373
|
elif isdir(input_name):
|
342
374
|
if mixid is None:
|
343
|
-
raise SonusAIError(
|
375
|
+
raise SonusAIError("Must specify mixid when input is mixture database")
|
344
376
|
|
345
377
|
mixdb = MixtureDatabase(input_name)
|
346
378
|
fg_step = mixdb.fg_step
|
347
379
|
|
348
380
|
print_mixture_details(mixdb=mixdb, mixid=mixid, desc_len=24, print_fn=logger.info)
|
349
381
|
|
350
|
-
logger.info(f
|
382
|
+
logger.info(f"Generating data for mixture {mixid}")
|
351
383
|
mixture_audio = mixdb.mixture_mixture(mixid)
|
352
384
|
target_audio = mixdb.mixture_target(mixid)
|
353
385
|
feature, truth_f = mixdb.mixture_ft(mixid)
|
354
386
|
t_indices = [x - 1 for x in get_truth_indices_for_mixid(mixdb=mixdb, mixid=mixid)]
|
355
387
|
|
356
388
|
target_files = [mixdb.target_file(target.file_id) for target in mixdb.mixtures[mixid].targets]
|
357
|
-
truth_functions = list(
|
358
|
-
energy =
|
389
|
+
truth_functions = list({sub2.function for sub1 in target_files for sub2 in sub1.truth_configs})
|
390
|
+
energy = "energy_f" in truth_functions or "snr_f" in truth_functions
|
359
391
|
|
360
|
-
logger.debug(f
|
361
|
-
logger.debug(
|
392
|
+
logger.debug(f"Audio samples {len(mixture_audio)}")
|
393
|
+
logger.debug("Targets:")
|
362
394
|
mixture = mixdb.mixture(mixid)
|
363
395
|
for target in mixture.targets:
|
364
396
|
target_file = mixdb.target_file(target.file_id)
|
365
397
|
name = target_file.name
|
366
398
|
duration = target_file.duration
|
367
399
|
augmentation = target.augmentation
|
368
|
-
logger.debug(f
|
369
|
-
logger.debug(f
|
370
|
-
logger.debug(f
|
400
|
+
logger.debug(f" Name {name}")
|
401
|
+
logger.debug(f" Duration {duration}")
|
402
|
+
logger.debug(f" Augmentation {augmentation}")
|
371
403
|
|
372
|
-
logger.debug(f
|
373
|
-
logger.debug(f
|
404
|
+
logger.debug(f"Feature shape {feature.shape}")
|
405
|
+
logger.debug(f"Truth shape {truth_f.shape}")
|
374
406
|
|
375
407
|
else:
|
376
|
-
raise SonusAIError(f
|
408
|
+
raise SonusAIError(f"Unknown file type for {input_name}")
|
377
409
|
|
378
410
|
predict = None
|
379
411
|
labels = None
|
380
412
|
indices = []
|
381
413
|
if model is not None:
|
382
|
-
logger.debug(
|
383
|
-
logger.info(f
|
384
|
-
logger.debug(f
|
385
|
-
logger.debug(f
|
386
|
-
logger.debug(f
|
414
|
+
logger.debug("")
|
415
|
+
logger.info(f"Running prediction on mixture {mixid}")
|
416
|
+
logger.debug(f"Model feature name {model.feature}")
|
417
|
+
logger.debug(f"Model input shape {model.input_shape}")
|
418
|
+
logger.debug(f"Model output shape {model.output_shape}")
|
387
419
|
|
388
420
|
if feature.shape[0] < model.input_shape[0]:
|
389
|
-
raise SonusAIError(
|
390
|
-
|
391
|
-
|
392
|
-
|
421
|
+
raise SonusAIError(
|
422
|
+
f"Mixture {mixid} contains {feature.shape[0]} "
|
423
|
+
f"frames of data which is not enough to run prediction; "
|
424
|
+
f"at least {model.input_shape[0]} frames are needed for this model.\n"
|
425
|
+
f"Consider using a model with a smaller batch size or a mixture with more data."
|
426
|
+
)
|
393
427
|
|
394
428
|
predict = model.execute(feature)
|
395
429
|
|
@@ -400,10 +434,10 @@ def main() -> None:
|
|
400
434
|
p_indices = np.argsort(p_max)[::-1][:5]
|
401
435
|
p_max_len = max([len(labels[i]) for i in p_indices])
|
402
436
|
|
403
|
-
logger.info(
|
437
|
+
logger.info("Top 5 active prediction classes by max:")
|
404
438
|
for p_index in p_indices:
|
405
|
-
logger.info(f
|
406
|
-
logger.info(
|
439
|
+
logger.info(f" {labels[p_index]:{p_max_len}s} {p_max[p_index]:.3f}")
|
440
|
+
logger.info("")
|
407
441
|
|
408
442
|
indices = list(p_indices)
|
409
443
|
|
@@ -414,26 +448,30 @@ def main() -> None:
|
|
414
448
|
|
415
449
|
base_name = basename(splitext(input_name)[0])
|
416
450
|
if mixdb is not None:
|
417
|
-
title = f
|
418
|
-
pdf_name = f
|
451
|
+
title = f"{input_name} Mixture {mixid}"
|
452
|
+
pdf_name = f"{base_name}-mix{mixid}-plot.pdf"
|
419
453
|
else:
|
420
|
-
title = f
|
421
|
-
pdf_name = f
|
454
|
+
title = f"{input_name}"
|
455
|
+
pdf_name = f"{base_name}-plot.pdf"
|
422
456
|
|
423
457
|
# Original size [frames, stride, feature_parameters]
|
424
458
|
# Decimate in the stride dimension
|
425
459
|
# Reshape to get frames*decimated_stride, feature_parameters
|
426
460
|
if feature.ndim != 3:
|
427
|
-
raise SonusAIError(
|
461
|
+
raise SonusAIError("feature does not have 3 dimensions: frames, stride, feature_parameters")
|
428
462
|
spectrogram = feature[:, -fg_step:, :]
|
429
463
|
spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))
|
430
464
|
|
431
465
|
with PdfPages(pdf_name) as pdf:
|
432
|
-
pdf.savefig(
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
466
|
+
pdf.savefig(
|
467
|
+
spec_plot(
|
468
|
+
mixture=mixture_audio,
|
469
|
+
feature=spectrogram,
|
470
|
+
predict=predict,
|
471
|
+
labels=labels,
|
472
|
+
title=title,
|
473
|
+
)
|
474
|
+
)
|
437
475
|
for index in indices:
|
438
476
|
if energy:
|
439
477
|
t_tmp = None
|
@@ -444,10 +482,14 @@ def main() -> None:
|
|
444
482
|
if predict is not None:
|
445
483
|
p_tmp = 10 * np.log10(predict + np.finfo(np.float32).eps)
|
446
484
|
|
447
|
-
pdf.savefig(
|
448
|
-
|
449
|
-
|
450
|
-
|
485
|
+
pdf.savefig(
|
486
|
+
spec_energy_plot(
|
487
|
+
mixture=mixture_audio,
|
488
|
+
feature=spectrogram,
|
489
|
+
truth_f=t_tmp,
|
490
|
+
predict=p_tmp,
|
491
|
+
)
|
492
|
+
)
|
451
493
|
else:
|
452
494
|
p_tmp = None
|
453
495
|
if predict is not None:
|
@@ -457,18 +499,22 @@ def main() -> None:
|
|
457
499
|
if labels is not None:
|
458
500
|
l_tmp = labels[index]
|
459
501
|
|
460
|
-
pdf.savefig(
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
502
|
+
pdf.savefig(
|
503
|
+
class_plot(
|
504
|
+
mixture=mixture_audio,
|
505
|
+
target=target_audio[index],
|
506
|
+
truth_f=truth_f[:, index],
|
507
|
+
predict=p_tmp,
|
508
|
+
label=l_tmp,
|
509
|
+
)
|
510
|
+
)
|
511
|
+
logger.info(f"Wrote {pdf_name}")
|
466
512
|
|
467
513
|
if output_name:
|
468
|
-
with h5py.File(output_name,
|
469
|
-
f.create_dataset(name=
|
470
|
-
logger.info(f
|
514
|
+
with h5py.File(output_name, "w") as f:
|
515
|
+
f.create_dataset(name="predict", data=predict)
|
516
|
+
logger.info(f"Wrote {output_name}")
|
471
517
|
|
472
518
|
|
473
|
-
if __name__ ==
|
519
|
+
if __name__ == "__main__":
|
474
520
|
main()
|