sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/deprecated/plot.py
@@ -0,0 +1,519 @@
+ """sonusai plot
+
+ usage: plot [-hve] [-i MIXID] [-m MODEL] [-l CSV] [-o OUTPUT] INPUT
+
+ options:
+     -h, --help
+     -v, --verbose               Be verbose.
+     -i MIXID, --mixid MIXID     Mixture to plot if input is a mixture database.
+     -m MODEL, --model MODEL     Trained model ONNX file.
+     -l CSV, --labels CSV        Optional CSV file of class labels (from SonusAI gentcst).
+     -o OUTPUT, --output OUTPUT  Optional output HDF5 file for prediction.
+     -e, --energy                Use energy plots.
+
+ Plot SonusAI audio, feature, truth, and prediction data. INPUT must be one of the following:
+
+ * WAV
+   Using the given model, generate feature data and run prediction. A model file must be
+   provided. The MIXID is ignored. If --energy is specified, plot predict data as energy.
+
+ * directory
+   Using the given SonusAI mixture database directory, generate feature and truth data if not found.
+   Run prediction if a model is given. The MIXID is required. (--energy is ignored.)
+
+ Prediction data will be written to OUTPUT if a model file is given and OUTPUT is specified.
+
+ There will be one plot per active truth index. In addition, the top 5 prediction classes are determined and
+ plotted if needed (i.e., if they were not already included in the truth plots). For plots generated from a
+ mixture database, the target will also be displayed. If mixup is active, each target involved will be added
+ to the corresponding truth plot.
+
+ Inputs:
+     MODEL   A SonusAI trained ONNX model file. If a model file is given, prediction data will be
+             generated.
+     INPUT   A WAV file, or
+             a directory containing a SonusAI mixture database
+
+ Outputs:
+     {INPUT}-plot.pdf or {INPUT}-mix{MIXID}-plot.pdf
+     plot.log
+     OUTPUT (if MODEL and OUTPUT are both specified)
+
+ """
+
+ import signal
+
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from sonusai.datatypes import AudioT
+ from sonusai.datatypes import Feature
+ from sonusai.datatypes import Predict
+ from sonusai.datatypes import Truth
+
+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info("Canceled due to keyboard interrupt")
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+
+ def spec_plot(
+     mixture: AudioT,
+     feature: Feature,
+     predict: Predict | None = None,
+     target: AudioT | None = None,
+     labels: list[str] | None = None,
+     title: str = "",
+ ) -> plt.Figure:
+     from sonusai.constants import SAMPLE_RATE
+
+     num_plots = 4 if predict is not None else 2
+     fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))
+
+     # Plot the waveform
+     x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
+     ax[0].plot(x_axis, mixture, label="Mixture")
+     ax[0].set_xlim(x_axis[0], x_axis[-1])
+     ax[0].set_ylim([-1.025, 1.025])
+     if target is not None:
+         # Plot target time-domain waveform on top of mixture
+         color = "tab:blue"
+         ax[0].plot(x_axis, target, color=color, label="Target")
+         ax[0].set_ylabel("magnitude", color=color)
+     ax[0].set_title("Waveform")
+
+     # Plot the spectrogram
+     ax[1].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
+     ax[1].set_title("Feature")
+
+     if predict is not None:
+         if labels is None:
+             raise ValueError("Provided predict without labels")
+
+         # Plot and label the model output scores for the top-scoring classes.
+         mean_predict = np.mean(predict, axis=0)
+         num_classes = predict.shape[-1]
+         top_n = min(10, num_classes)
+         top_class_indices = np.argsort(mean_predict)[::-1][:top_n]
+         ax[2].imshow(
+             np.transpose(predict[:, top_class_indices]),
+             aspect="auto",
+             interpolation="nearest",
+             cmap="gray_r",
+         )
+         y_ticks = range(0, top_n)
+         ax[2].set_yticks(y_ticks, [labels[top_class_indices[x]] for x in y_ticks])
+         ax[2].set_ylim(-0.5 + np.array([top_n, 0]))
+         ax[2].set_title("Class Scores")
+
+         # Plot the probabilities
+         ax[3].plot(predict[:, top_class_indices])
+         ax[3].legend(np.array(labels)[top_class_indices], loc="best")
+         ax[3].set_title("Class Probabilities")
+
+     fig.suptitle(title)
+
+     return fig
+
+
+ def spec_energy_plot(
+     mixture: AudioT, feature: Feature, truth_f: Truth | None = None, predict: Predict | None = None
+ ) -> plt.Figure:
+     from sonusai.constants import SAMPLE_RATE
+
+     num_plots = 2
+     if truth_f is not None:
+         num_plots += 1
+     if predict is not None:
+         num_plots += 1
+
+     fig, ax = plt.subplots(num_plots, 1, constrained_layout=True, figsize=(11, 8.5))
+
+     # Plot the waveform
+     p = 0
+     x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
+     ax[p].plot(x_axis, mixture, label="Mixture")
+     ax[p].set_xlim(x_axis[0], x_axis[-1])
+     ax[p].set_ylim([-1.025, 1.025])
+     ax[p].set_title("Waveform")
+
+     # Plot the spectrogram
+     p += 1
+     ax[p].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
+     ax[p].set_title("Feature")
+
+     if truth_f is not None:
+         p += 1
+         ax[p].imshow(
+             np.transpose(truth_f),
+             aspect="auto",
+             interpolation="nearest",
+             origin="lower",
+         )
+         ax[p].set_title("Truth")
+
+     if predict is not None:
+         p += 1
+         ax[p].imshow(
+             np.transpose(predict),
+             aspect="auto",
+             interpolation="nearest",
+             origin="lower",
+         )
+         ax[p].set_title("Predict")
+
+     return fig
+
+
+ def class_plot(
+     mixture: AudioT,
+     target: AudioT | None = None,
+     truth_f: Truth | None = None,
+     predict: Predict | None = None,
+     label: str = "",
+ ) -> plt.Figure:
+     """Plot mixture waveform with optional prediction and/or truth together in a single plot
+
+     The target waveform can optionally be provided, and prediction and truth can have multiple classes.
+
+     Inputs:
+         mixture     required, 1-D numpy array [samples]
+         target      optional, 1-D numpy array [samples]
+         truth_f     optional, 1-D numpy array [frames]
+         predict     optional, 1-D numpy array [frames]
+         label       optional, label name to use when plotting
+
+     """
+     from sonusai import SonusAIError
+     from sonusai.constants import SAMPLE_RATE
+
+     if mixture.ndim != 1:
+         raise SonusAIError("Too many dimensions in mixture")
+
+     if target is not None and target.ndim != 1:
+         raise SonusAIError("Too many dimensions in target")
+
+     # Set default to 1 frame when there is no truth or predict data
+     frames = 1
+     if truth_f is not None and predict is not None:
+         if truth_f.ndim != 1:
+             raise SonusAIError("Too many dimensions in truth_f")
+         t_frames = len(truth_f)
+
+         if predict.ndim != 1:
+             raise SonusAIError("Too many dimensions in predict")
+         p_frames = len(predict)
+
+         frames = min(t_frames, p_frames)
+     elif truth_f is not None:
+         if truth_f.ndim != 1:
+             raise SonusAIError("Too many dimensions in truth_f")
+         frames = len(truth_f)
+     elif predict is not None:
+         if predict.ndim != 1:
+             raise SonusAIError("Too many dimensions in predict")
+         frames = len(predict)
+
+     samples = (len(mixture) // frames) * frames
+
+     # x-axis in sec
+     x_axis = np.arange(samples, dtype=np.float32) / SAMPLE_RATE
+
+     fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(11, 8.5))
+
+     # Plot the time-domain waveforms, then truth/prediction on a second axis
+     ax.plot(x_axis, mixture[0:samples], color="mistyrose", label="Mixture")
+     color = "red"
+     ax.set_xlim(x_axis[0], x_axis[-1])
+     ax.set_ylim((-1.025, 1.025))
+     ax.set_ylabel("Amplitude", color=color)
+     ax.tick_params(axis="y", labelcolor=color)
+
+     # Plot target time-domain waveform
+     if target is not None:
+         ax.plot(x_axis, target[0:samples], color="blue", label="Target")
+
+     # Instantiate a second y-axis that shares the same x-axis
+     if truth_f is not None or predict is not None:
+         y_label = "Truth/Predict"
+         if truth_f is None:
+             y_label = "Predict"
+         if predict is None:
+             y_label = "Truth"
+
+         ax2 = ax.twinx()
+
+         color = "black"
+         ax2.set_xlim(x_axis[0], x_axis[-1])
+         ax2.set_ylim((-0.025, 1.025))
+         ax2.set_ylabel(y_label, color=color)
+         ax2.tick_params(axis="y", labelcolor=color)
+
+         if truth_f is not None:
+             ax2.plot(
+                 x_axis,
+                 expand_frames_to_samples(truth_f, samples),
+                 color="green",
+                 label="Truth",
+             )
+
+         if predict is not None:
+             ax2.plot(
+                 x_axis,
+                 expand_frames_to_samples(predict, samples),
+                 color="brown",
+                 label="Predict",
+             )
+
+     # Set only on last/bottom plot
+     ax.set_xlabel("time (s)")
+
+     fig.suptitle(label)
+
+     return fig
+
+
+ def expand_frames_to_samples(x: np.ndarray, samples: int) -> np.ndarray:
+     # Repeat each frame value so frame-rate data lines up with the sample-rate waveform
+     samples_per_frame = samples // len(x)
+     return np.reshape(np.tile(np.expand_dims(x, 1), [1, samples_per_frame]), samples)
+
+
+ def main() -> None:
+     from docopt import docopt
+
+     import sonusai
+     from sonusai.utils.docstring import trim_docstring
+
+     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
+
+     from dataclasses import asdict
+     from os.path import basename
+     from os.path import exists
+     from os.path import isdir
+     from os.path import splitext
+
+     import h5py
+     from matplotlib.backends.backend_pdf import PdfPages
+     from pyaaware import FeatureGenerator
+     from pyaaware import Predict
+
+     from sonusai import SonusAIError
+     from sonusai import create_file_handler
+     from sonusai import initial_log_messages
+     from sonusai import logger
+     from sonusai import update_console_handler
+     from sonusai.mixture import FeatureGeneratorConfig
+     from sonusai.mixture import MixtureDatabase
+     from sonusai.mixture import get_feature_from_audio
+     from sonusai.mixture import get_truth_indices_for_mixid
+     from sonusai.mixture.audio import read_audio
+     from sonusai.utils.get_label_names import get_label_names
+     from sonusai.utils.print_mixture_details import print_mixture_details
+
+     verbose = args["--verbose"]
+     model_name = args["--model"]
+     output_name = args["--output"]
+     labels_name = args["--labels"]
+     mixid = args["--mixid"]
+     energy = args["--energy"]
+     input_name = args["INPUT"]
+
+     if mixid is not None:
+         mixid = int(mixid)
+
+     create_file_handler("plot.log")
+     update_console_handler(verbose)
+     initial_log_messages("plot")
+
+     if not exists(input_name):
+         raise SonusAIError(f"{input_name} does not exist")
+
+     logger.info("")
+     logger.info(f"Input: {input_name}")
+     if model_name is not None:
+         logger.info(f"Model: {model_name}")
+     if output_name is not None:
+         logger.info(f"Output: {output_name}")
+     logger.info("")
+
+     ext = splitext(input_name)[1]
+
+     model = None
+     target_audio = None
+     truth_f = None
+     t_indices = []
+
+     if model_name is not None:
+         model = Predict(model_name)
+
+     if ext == ".wav":
+         if model is None:
+             raise SonusAIError("Must specify MODEL when input is WAV")
+
+         mixture_audio = read_audio(input_name)
+         feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
+         fg_config = FeatureGeneratorConfig(
+             feature_mode=model.feature,
+             num_classes=model.output_shape[-1],
+             truth_mutex=False,
+         )
+         fg = FeatureGenerator(**asdict(fg_config))
+         fg_step = fg.step
+         mixdb = None
+         logger.debug(f"Audio samples {len(mixture_audio)}")
+         logger.debug(f"Feature shape {feature.shape}")
+
+     elif isdir(input_name):
+         if mixid is None:
+             raise SonusAIError("Must specify mixid when input is mixture database")
+
+         mixdb = MixtureDatabase(input_name)
+         fg_step = mixdb.fg_step
+
+         print_mixture_details(mixdb=mixdb, mixid=mixid, desc_len=24, print_fn=logger.info)
+
+         logger.info(f"Generating data for mixture {mixid}")
+         mixture_audio = mixdb.mixture_mixture(mixid)
+         target_audio = mixdb.mixture_target(mixid)
+         feature, truth_f = mixdb.mixture_ft(mixid)
+         t_indices = [x - 1 for x in get_truth_indices_for_mixid(mixdb=mixdb, mixid=mixid)]
+
+         # For a mixture database, energy plotting is determined by the truth functions (--energy is ignored)
+         target_files = [mixdb.target_file(target.file_id) for target in mixdb.mixtures[mixid].targets]
+         truth_functions = list({sub2.function for sub1 in target_files for sub2 in sub1.truth_configs})
+         energy = "energy_f" in truth_functions or "snr_f" in truth_functions
+
+         logger.debug(f"Audio samples {len(mixture_audio)}")
+         logger.debug("Targets:")
+         mixture = mixdb.mixture(mixid)
+         for target in mixture.targets:
+             target_file = mixdb.target_file(target.file_id)
+             name = target_file.name
+             duration = target_file.duration
+             augmentation = target.augmentation
+             logger.debug(f"  Name          {name}")
+             logger.debug(f"  Duration      {duration}")
+             logger.debug(f"  Augmentation  {augmentation}")
+
+         logger.debug(f"Feature shape {feature.shape}")
+         logger.debug(f"Truth shape {truth_f.shape}")
+
+     else:
+         raise SonusAIError(f"Unknown file type for {input_name}")
+
+     predict = None
+     labels = None
+     indices = []
+     if model is not None:
+         logger.debug("")
+         logger.info(f"Running prediction on mixture {mixid}")
+         logger.debug(f"Model feature name {model.feature}")
+         logger.debug(f"Model input shape {model.input_shape}")
+         logger.debug(f"Model output shape {model.output_shape}")
+
+         if feature.shape[0] < model.input_shape[0]:
+             raise SonusAIError(
+                 f"Mixture {mixid} contains {feature.shape[0]} "
+                 f"frames of data which is not enough to run prediction; "
+                 f"at least {model.input_shape[0]} frames are needed for this model.\n"
+                 f"Consider using a model with a smaller batch size or a mixture with more data."
+             )
+
+         predict = model.execute(feature)
+
+         labels = get_label_names(num_labels=predict.shape[1], file=labels_name)
+
+         # Report the highest-scoring classes and their scores.
+         p_max = np.max(predict, axis=0)
+         p_indices = np.argsort(p_max)[::-1][:5]
+         p_max_len = max([len(labels[i]) for i in p_indices])
+
+         logger.info("Top 5 active prediction classes by max:")
+         for p_index in p_indices:
+             logger.info(f"  {labels[p_index]:{p_max_len}s} {p_max[p_index]:.3f}")
+         logger.info("")
+
+         indices = list(p_indices)
+
+     # Add truth indices for target (if needed)
+     for t_index in t_indices:
+         if t_index not in indices:
+             indices.append(t_index)
+
+     base_name = basename(splitext(input_name)[0])
+     if mixdb is not None:
+         title = f"{input_name} Mixture {mixid}"
+         pdf_name = f"{base_name}-mix{mixid}-plot.pdf"
+     else:
+         title = f"{input_name}"
+         pdf_name = f"{base_name}-plot.pdf"
+
+     # Original size [frames, stride, feature_parameters]
+     # Decimate in the stride dimension
+     # Reshape to get frames*decimated_stride, feature_parameters
+     if feature.ndim != 3:
+         raise SonusAIError("feature does not have 3 dimensions: frames, stride, feature_parameters")
+     spectrogram = feature[:, -fg_step:, :]
+     spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))
+
+     with PdfPages(pdf_name) as pdf:
+         pdf.savefig(
+             spec_plot(
+                 mixture=mixture_audio,
+                 feature=spectrogram,
+                 predict=predict,
+                 labels=labels,
+                 title=title,
+             )
+         )
+         for index in indices:
+             if energy:
+                 t_tmp = None
+                 if truth_f is not None:
+                     t_tmp = 10 * np.log10(truth_f + np.finfo(np.float32).eps)
+
+                 p_tmp = None
+                 if predict is not None:
+                     p_tmp = 10 * np.log10(predict + np.finfo(np.float32).eps)
+
+                 pdf.savefig(
+                     spec_energy_plot(
+                         mixture=mixture_audio,
+                         feature=spectrogram,
+                         truth_f=t_tmp,
+                         predict=p_tmp,
+                     )
+                 )
+             else:
+                 # Guard optional data: truth and target are not available when the input is a WAV file
+                 t_tmp = None
+                 if target_audio is not None:
+                     t_tmp = target_audio[index]
+
+                 f_tmp = None
+                 if truth_f is not None:
+                     f_tmp = truth_f[:, index]
+
+                 p_tmp = None
+                 if predict is not None:
+                     p_tmp = predict[:, index]
+
+                 l_tmp = ""
+                 if labels is not None:
+                     l_tmp = labels[index]
+
+                 pdf.savefig(
+                     class_plot(
+                         mixture=mixture_audio,
+                         target=t_tmp,
+                         truth_f=f_tmp,
+                         predict=p_tmp,
+                         label=l_tmp,
+                     )
+                 )
+
+     logger.info(f"Wrote {pdf_name}")
+
+     if output_name:
+         with h5py.File(output_name, "w") as f:
+             f.create_dataset(name="predict", data=predict)
+         logger.info(f"Wrote {output_name}")
+
+
+ if __name__ == "__main__":
+     main()
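
For reference, a minimal standalone sketch (not part of the package) of the frame-to-sample expansion that class_plot above relies on to overlay frame-rate truth/predict data on the sample-rate waveform. The array sizes and values are made up for illustration; only numpy is required.

    import numpy as np

    def expand_frames_to_samples(x: np.ndarray, samples: int) -> np.ndarray:
        # Repeat each frame value so frame-rate data lines up with the sample-rate waveform
        samples_per_frame = samples // len(x)
        return np.reshape(np.tile(np.expand_dims(x, 1), [1, samples_per_frame]), samples)

    truth_f = np.array([0.0, 1.0, 0.5])  # three frames of truth data (hypothetical values)
    samples = 12                         # waveform length; 12 // 3 = 4 samples per frame
    print(expand_frames_to_samples(truth_f, samples))
    # -> 0 0 0 0 1 1 1 1 0.5 0.5 0.5 0.5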