sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -57,19 +57,21 @@ def signal_handler(_sig, _frame):
57
57
 
58
58
  from sonusai import logger
59
59
 
60
- logger.info('Canceled due to keyboard interrupt')
60
+ logger.info("Canceled due to keyboard interrupt")
61
61
  sys.exit(1)
62
62
 
63
63
 
64
64
  signal.signal(signal.SIGINT, signal_handler)
65
65
 
66
66
 
67
- def spec_plot(mixture: AudioT,
68
- feature: Feature,
69
- predict: Predict = None,
70
- target: AudioT = None,
71
- labels: list[str] = None,
72
- title: str = '') -> plt.Figure:
67
+ def spec_plot(
68
+ mixture: AudioT,
69
+ feature: Feature,
70
+ predict: Predict | None = None,
71
+ target: AudioT | None = None,
72
+ labels: list[str] | None = None,
73
+ title: str = "",
74
+ ) -> plt.Figure:
73
75
  from sonusai.mixture import SAMPLE_RATE
74
76
 
75
77
  num_plots = 4 if predict is not None else 2
@@ -77,46 +79,53 @@ def spec_plot(mixture: AudioT,
77
79
 
78
80
  # Plot the waveform
79
81
  x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
80
- ax[0].plot(x_axis, mixture, label='Mixture')
82
+ ax[0].plot(x_axis, mixture, label="Mixture")
81
83
  ax[0].set_xlim(x_axis[0], x_axis[-1])
82
84
  ax[0].set_ylim([-1.025, 1.025])
83
85
  if target is not None:
84
86
  # Plot target time-domain waveform on top of mixture
85
- color = 'tab:blue'
86
- ax[0].plot(x_axis, target, color=color, label='Target')
87
- ax[0].set_ylabel('magnitude', color=color)
88
- ax[0].set_title('Waveform')
87
+ color = "tab:blue"
88
+ ax[0].plot(x_axis, target, color=color, label="Target")
89
+ ax[0].set_ylabel("magnitude", color=color)
90
+ ax[0].set_title("Waveform")
89
91
 
90
92
  # Plot the spectrogram
91
- ax[1].imshow(np.transpose(feature), aspect='auto', interpolation='nearest', origin='lower')
92
- ax[1].set_title('Feature')
93
+ ax[1].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
94
+ ax[1].set_title("Feature")
93
95
 
94
96
  if predict is not None:
97
+ if labels is None:
98
+ raise ValueError("Provided predict without labels")
99
+
95
100
  # Plot and label the model output scores for the top-scoring classes.
96
101
  mean_predict = np.mean(predict, axis=0)
97
102
  num_classes = predict.shape[-1]
98
103
  top_n = min(10, num_classes)
99
104
  top_class_indices = np.argsort(mean_predict)[::-1][:top_n]
100
- ax[2].imshow(np.transpose(predict[:, top_class_indices]), aspect='auto', interpolation='nearest', cmap='gray_r')
105
+ ax[2].imshow(
106
+ np.transpose(predict[:, top_class_indices]),
107
+ aspect="auto",
108
+ interpolation="nearest",
109
+ cmap="gray_r",
110
+ )
101
111
  y_ticks = range(0, top_n)
102
112
  ax[2].set_yticks(y_ticks, [labels[top_class_indices[x]] for x in y_ticks])
103
113
  ax[2].set_ylim(-0.5 + np.array([top_n, 0]))
104
- ax[2].set_title('Class Scores')
114
+ ax[2].set_title("Class Scores")
105
115
 
106
116
  # Plot the probabilities
107
117
  ax[3].plot(predict[:, top_class_indices])
108
- ax[3].legend(np.array(labels)[top_class_indices], loc='best')
109
- ax[3].set_title('Class Probabilities')
118
+ ax[3].legend(np.array(labels)[top_class_indices], loc="best")
119
+ ax[3].set_title("Class Probabilities")
110
120
 
111
121
  fig.suptitle(title)
112
122
 
113
123
  return fig
114
124
 
115
125
 
116
- def spec_energy_plot(mixture: AudioT,
117
- feature: Feature,
118
- truth_f: Truth = None,
119
- predict: Predict = None) -> plt.Figure:
126
+ def spec_energy_plot(
127
+ mixture: AudioT, feature: Feature, truth_f: Truth | None = None, predict: Predict | None = None
128
+ ) -> plt.Figure:
120
129
  from sonusai.mixture import SAMPLE_RATE
121
130
 
122
131
  num_plots = 2
@@ -130,35 +139,47 @@ def spec_energy_plot(mixture: AudioT,
130
139
  # Plot the waveform
131
140
  p = 0
132
141
  x_axis = np.arange(len(mixture), dtype=np.float32) / SAMPLE_RATE
133
- ax[p].plot(x_axis, mixture, label='Mixture')
142
+ ax[p].plot(x_axis, mixture, label="Mixture")
134
143
  ax[p].set_xlim(x_axis[0], x_axis[-1])
135
144
  ax[p].set_ylim([-1.025, 1.025])
136
- ax[p].set_title('Waveform')
145
+ ax[p].set_title("Waveform")
137
146
 
138
147
  # Plot the spectrogram
139
148
  p += 1
140
- ax[p].imshow(np.transpose(feature), aspect='auto', interpolation='nearest', origin='lower')
141
- ax[p].set_title('Feature')
149
+ ax[p].imshow(np.transpose(feature), aspect="auto", interpolation="nearest", origin="lower")
150
+ ax[p].set_title("Feature")
142
151
 
143
152
  if truth_f is not None:
144
153
  p += 1
145
- ax[p].imshow(np.transpose(truth_f), aspect='auto', interpolation='nearest', origin='lower')
146
- ax[p].set_title('Truth')
154
+ ax[p].imshow(
155
+ np.transpose(truth_f),
156
+ aspect="auto",
157
+ interpolation="nearest",
158
+ origin="lower",
159
+ )
160
+ ax[p].set_title("Truth")
147
161
 
148
162
  if predict is not None:
149
163
  p += 1
150
- ax[p].imshow(np.transpose(predict), aspect='auto', interpolation='nearest', origin='lower')
151
- ax[p].set_title('Predict')
164
+ ax[p].imshow(
165
+ np.transpose(predict),
166
+ aspect="auto",
167
+ interpolation="nearest",
168
+ origin="lower",
169
+ )
170
+ ax[p].set_title("Predict")
152
171
 
153
172
  return fig
154
173
 
155
174
 
156
- def class_plot(mixture: AudioT,
157
- target: AudioT = None,
158
- truth_f: Truth = None,
159
- predict: Predict = None,
160
- label: str = '') -> plt.Figure:
161
- """ Plot mixture waveform with optional prediction and/or truth together in a single plot
175
+ def class_plot(
176
+ mixture: AudioT,
177
+ target: AudioT | None = None,
178
+ truth_f: Truth | None = None,
179
+ predict: Predict | None = None,
180
+ label: str = "",
181
+ ) -> plt.Figure:
182
+ """Plot mixture waveform with optional prediction and/or truth together in a single plot
162
183
 
163
184
  The target waveform can optionally be provided, and prediction and truth can have multiple classes.
164
185
 
@@ -174,30 +195,30 @@ def class_plot(mixture: AudioT,
174
195
  from sonusai.mixture import SAMPLE_RATE
175
196
 
176
197
  if mixture.ndim != 1:
177
- raise SonusAIError('Too many dimensions in mixture')
198
+ raise SonusAIError("Too many dimensions in mixture")
178
199
 
179
200
  if target is not None and target.ndim != 1:
180
- raise SonusAIError('Too many dimensions in target')
201
+ raise SonusAIError("Too many dimensions in target")
181
202
 
182
203
  # Set default to 1 frame when there is no truth or predict data
183
204
  frames = 1
184
205
  if truth_f is not None and predict is not None:
185
206
  if truth_f.ndim != 1:
186
- raise SonusAIError('Too many dimensions in truth_f')
207
+ raise SonusAIError("Too many dimensions in truth_f")
187
208
  t_frames = len(truth_f)
188
209
 
189
210
  if predict.ndim != 1:
190
- raise SonusAIError('Too many dimensions in predict')
211
+ raise SonusAIError("Too many dimensions in predict")
191
212
  p_frames = len(predict)
192
213
 
193
214
  frames = min(t_frames, p_frames)
194
215
  elif truth_f is not None:
195
216
  if truth_f.ndim != 1:
196
- raise SonusAIError('Too many dimensions in truth_f')
217
+ raise SonusAIError("Too many dimensions in truth_f")
197
218
  frames = len(truth_f)
198
219
  elif predict is not None:
199
220
  if predict.ndim != 1:
200
- raise SonusAIError('Too many dimensions in predict')
221
+ raise SonusAIError("Too many dimensions in predict")
201
222
  frames = len(predict)
202
223
 
203
224
  samples = (len(mixture) // frames) * frames
@@ -208,41 +229,51 @@ def class_plot(mixture: AudioT,
208
229
  fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(11, 8.5))
209
230
 
210
231
  # Plot the time-domain waveforms then truth/prediction on second axis
211
- ax.plot(x_axis, mixture[0:samples], color='mistyrose', label='Mixture')
212
- color = 'red'
232
+ ax.plot(x_axis, mixture[0:samples], color="mistyrose", label="Mixture")
233
+ color = "red"
213
234
  ax.set_xlim(x_axis[0], x_axis[-1])
214
- ax.set_ylim([-1.025, 1.025])
215
- ax.set_ylabel(f'Amplitude', color=color)
216
- ax.tick_params(axis='y', labelcolor=color)
235
+ ax.set_ylim((-1.025, 1.025))
236
+ ax.set_ylabel("Amplitude", color=color)
237
+ ax.tick_params(axis="y", labelcolor=color)
217
238
 
218
239
  # Plot target time-domain waveform
219
240
  if target is not None:
220
- ax.plot(x_axis, target[0:samples], color='blue', label='Target')
241
+ ax.plot(x_axis, target[0:samples], color="blue", label="Target")
221
242
 
222
243
  # instantiate 2nd y-axis that shares the same x-axis
223
244
  if truth_f is not None or predict is not None:
224
- y_label = 'Truth/Predict'
245
+ y_label = "Truth/Predict"
225
246
  if truth_f is None:
226
- y_label = 'Predict'
247
+ y_label = "Predict"
227
248
  if predict is None:
228
- y_label = 'Truth'
249
+ y_label = "Truth"
229
250
 
230
251
  ax2 = ax.twinx()
231
252
 
232
- color = 'black'
253
+ color = "black"
233
254
  ax2.set_xlim(x_axis[0], x_axis[-1])
234
- ax2.set_ylim([-0.025, 1.025])
255
+ ax2.set_ylim((-0.025, 1.025))
235
256
  ax2.set_ylabel(y_label, color=color)
236
- ax2.tick_params(axis='y', labelcolor=color)
257
+ ax2.tick_params(axis="y", labelcolor=color)
237
258
 
238
259
  if truth_f is not None:
239
- ax2.plot(x_axis, expand_frames_to_samples(truth_f, samples), color='green', label='Truth')
260
+ ax2.plot(
261
+ x_axis,
262
+ expand_frames_to_samples(truth_f, samples),
263
+ color="green",
264
+ label="Truth",
265
+ )
240
266
 
241
267
  if predict is not None:
242
- ax2.plot(x_axis, expand_frames_to_samples(predict, samples), color='brown', label='Predict')
268
+ ax2.plot(
269
+ x_axis,
270
+ expand_frames_to_samples(predict, samples),
271
+ color="brown",
272
+ label="Predict",
273
+ )
243
274
 
244
275
  # set only on last/bottom plot
245
- ax.set_xlabel('time (s)')
276
+ ax.set_xlabel("time (s)")
246
277
 
247
278
  fig.suptitle(label)
248
279
 
@@ -263,7 +294,6 @@ def main() -> None:
263
294
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
264
295
 
265
296
  from dataclasses import asdict
266
-
267
297
  from os.path import basename
268
298
  from os.path import exists
269
299
  from os.path import isdir
@@ -279,117 +309,121 @@ def main() -> None:
279
309
  from sonusai import initial_log_messages
280
310
  from sonusai import logger
281
311
  from sonusai import update_console_handler
282
- from sonusai.mixture import MixtureDatabase
283
312
  from sonusai.mixture import FeatureGeneratorConfig
313
+ from sonusai.mixture import MixtureDatabase
284
314
  from sonusai.mixture import get_feature_from_audio
285
315
  from sonusai.mixture import get_truth_indices_for_mixid
286
316
  from sonusai.mixture import read_audio
287
317
  from sonusai.utils import get_label_names
288
318
  from sonusai.utils import print_mixture_details
289
319
 
290
- verbose = args['--verbose']
291
- model_name = args['--model']
292
- output_name = args['--output']
293
- labels_name = args['--labels']
294
- mixid = args['--mixid']
295
- energy = args['--energy']
296
- input_name = args['INPUT']
320
+ verbose = args["--verbose"]
321
+ model_name = args["--model"]
322
+ output_name = args["--output"]
323
+ labels_name = args["--labels"]
324
+ mixid = args["--mixid"]
325
+ energy = args["--energy"]
326
+ input_name = args["INPUT"]
297
327
 
298
328
  if mixid is not None:
299
329
  mixid = int(mixid)
300
330
 
301
- create_file_handler('plot.log')
331
+ create_file_handler("plot.log")
302
332
  update_console_handler(verbose)
303
- initial_log_messages('plot')
333
+ initial_log_messages("plot")
304
334
 
305
335
  if not exists(input_name):
306
- raise SonusAIError(f'{input_name} does not exist')
336
+ raise SonusAIError(f"{input_name} does not exist")
307
337
 
308
- logger.info('')
309
- logger.info(f'Input: {input_name}')
338
+ logger.info("")
339
+ logger.info(f"Input: {input_name}")
310
340
  if model_name is not None:
311
- logger.info(f'Model: {model_name}')
341
+ logger.info(f"Model: {model_name}")
312
342
  if output_name is not None:
313
- logger.info(f'Output: {output_name}')
314
- logger.info('')
343
+ logger.info(f"Output: {output_name}")
344
+ logger.info("")
315
345
 
316
346
  ext = splitext(input_name)[1]
317
347
 
318
348
  model = None
319
349
  target_audio = None
320
350
  truth_f = None
321
- t_indices = None
351
+ t_indices = []
322
352
 
323
353
  if model_name is not None:
324
354
  model = Predict(model_name)
325
355
 
326
- if ext == '.wav':
327
- if model_name is None:
328
- raise SonusAIError('Must specify MODEL when input is WAV')
356
+ if ext == ".wav":
357
+ if model is None:
358
+ raise SonusAIError("Must specify MODEL when input is WAV")
329
359
 
330
360
  mixture_audio = read_audio(input_name)
331
361
  feature = get_feature_from_audio(audio=mixture_audio, feature_mode=model.feature)
332
- fg_config = FeatureGeneratorConfig(feature_mode=model.feature,
333
- num_classes=model.output_shape[-1],
334
- truth_mutex=False)
362
+ fg_config = FeatureGeneratorConfig(
363
+ feature_mode=model.feature,
364
+ num_classes=model.output_shape[-1],
365
+ truth_mutex=False,
366
+ )
335
367
  fg = FeatureGenerator(**asdict(fg_config))
336
368
  fg_step = fg.step
337
369
  mixdb = None
338
- logger.debug(f'Audio samples {len(mixture_audio)}')
339
- logger.debug(f'Feature shape {feature.shape}')
370
+ logger.debug(f"Audio samples {len(mixture_audio)}")
371
+ logger.debug(f"Feature shape {feature.shape}")
340
372
 
341
373
  elif isdir(input_name):
342
374
  if mixid is None:
343
- raise SonusAIError('Must specify mixid when input is mixture database')
375
+ raise SonusAIError("Must specify mixid when input is mixture database")
344
376
 
345
377
  mixdb = MixtureDatabase(input_name)
346
378
  fg_step = mixdb.fg_step
347
379
 
348
380
  print_mixture_details(mixdb=mixdb, mixid=mixid, desc_len=24, print_fn=logger.info)
349
381
 
350
- logger.info(f'Generating data for mixture {mixid}')
382
+ logger.info(f"Generating data for mixture {mixid}")
351
383
  mixture_audio = mixdb.mixture_mixture(mixid)
352
384
  target_audio = mixdb.mixture_target(mixid)
353
385
  feature, truth_f = mixdb.mixture_ft(mixid)
354
386
  t_indices = [x - 1 for x in get_truth_indices_for_mixid(mixdb=mixdb, mixid=mixid)]
355
387
 
356
388
  target_files = [mixdb.target_file(target.file_id) for target in mixdb.mixtures[mixid].targets]
357
- truth_functions = list(set([sub2.function for sub1 in target_files for sub2 in sub1.truth_settings]))
358
- energy = 'energy_f' in truth_functions or 'snr_f' in truth_functions
389
+ truth_functions = list({sub2.function for sub1 in target_files for sub2 in sub1.truth_configs})
390
+ energy = "energy_f" in truth_functions or "snr_f" in truth_functions
359
391
 
360
- logger.debug(f'Audio samples {len(mixture_audio)}')
361
- logger.debug(f'Targets:')
392
+ logger.debug(f"Audio samples {len(mixture_audio)}")
393
+ logger.debug("Targets:")
362
394
  mixture = mixdb.mixture(mixid)
363
395
  for target in mixture.targets:
364
396
  target_file = mixdb.target_file(target.file_id)
365
397
  name = target_file.name
366
398
  duration = target_file.duration
367
399
  augmentation = target.augmentation
368
- logger.debug(f' Name {name}')
369
- logger.debug(f' Duration {duration}')
370
- logger.debug(f' Augmentation {augmentation}')
400
+ logger.debug(f" Name {name}")
401
+ logger.debug(f" Duration {duration}")
402
+ logger.debug(f" Augmentation {augmentation}")
371
403
 
372
- logger.debug(f'Feature shape {feature.shape}')
373
- logger.debug(f'Truth shape {truth_f.shape}')
404
+ logger.debug(f"Feature shape {feature.shape}")
405
+ logger.debug(f"Truth shape {truth_f.shape}")
374
406
 
375
407
  else:
376
- raise SonusAIError(f'Unknown file type for {input_name}')
408
+ raise SonusAIError(f"Unknown file type for {input_name}")
377
409
 
378
410
  predict = None
379
411
  labels = None
380
412
  indices = []
381
413
  if model is not None:
382
- logger.debug('')
383
- logger.info(f'Running prediction on mixture {mixid}')
384
- logger.debug(f'Model feature name {model.feature}')
385
- logger.debug(f'Model input shape {model.input_shape}')
386
- logger.debug(f'Model output shape {model.output_shape}')
414
+ logger.debug("")
415
+ logger.info(f"Running prediction on mixture {mixid}")
416
+ logger.debug(f"Model feature name {model.feature}")
417
+ logger.debug(f"Model input shape {model.input_shape}")
418
+ logger.debug(f"Model output shape {model.output_shape}")
387
419
 
388
420
  if feature.shape[0] < model.input_shape[0]:
389
- raise SonusAIError(f'Mixture {mixid} contains {feature.shape[0]} '
390
- f'frames of data which is not enough to run prediction; '
391
- f'at least {model.input_shape[0]} frames are needed for this model.\n'
392
- f'Consider using a model with a smaller batch size or a mixture with more data.')
421
+ raise SonusAIError(
422
+ f"Mixture {mixid} contains {feature.shape[0]} "
423
+ f"frames of data which is not enough to run prediction; "
424
+ f"at least {model.input_shape[0]} frames are needed for this model.\n"
425
+ f"Consider using a model with a smaller batch size or a mixture with more data."
426
+ )
393
427
 
394
428
  predict = model.execute(feature)
395
429
 
@@ -400,10 +434,10 @@ def main() -> None:
400
434
  p_indices = np.argsort(p_max)[::-1][:5]
401
435
  p_max_len = max([len(labels[i]) for i in p_indices])
402
436
 
403
- logger.info('Top 5 active prediction classes by max:')
437
+ logger.info("Top 5 active prediction classes by max:")
404
438
  for p_index in p_indices:
405
- logger.info(f' {labels[p_index]:{p_max_len}s} {p_max[p_index]:.3f}')
406
- logger.info('')
439
+ logger.info(f" {labels[p_index]:{p_max_len}s} {p_max[p_index]:.3f}")
440
+ logger.info("")
407
441
 
408
442
  indices = list(p_indices)
409
443
 
@@ -414,26 +448,30 @@ def main() -> None:
414
448
 
415
449
  base_name = basename(splitext(input_name)[0])
416
450
  if mixdb is not None:
417
- title = f'{input_name} Mixture {mixid}'
418
- pdf_name = f'{base_name}-mix{mixid}-plot.pdf'
451
+ title = f"{input_name} Mixture {mixid}"
452
+ pdf_name = f"{base_name}-mix{mixid}-plot.pdf"
419
453
  else:
420
- title = f'{input_name}'
421
- pdf_name = f'{base_name}-plot.pdf'
454
+ title = f"{input_name}"
455
+ pdf_name = f"{base_name}-plot.pdf"
422
456
 
423
457
  # Original size [frames, stride, feature_parameters]
424
458
  # Decimate in the stride dimension
425
459
  # Reshape to get frames*decimated_stride, feature_parameters
426
460
  if feature.ndim != 3:
427
- raise SonusAIError(f'feature does not have 3 dimensions: frames, stride, feature_parameters')
461
+ raise SonusAIError("feature does not have 3 dimensions: frames, stride, feature_parameters")
428
462
  spectrogram = feature[:, -fg_step:, :]
429
463
  spectrogram = np.reshape(spectrogram, (spectrogram.shape[0] * spectrogram.shape[1], spectrogram.shape[2]))
430
464
 
431
465
  with PdfPages(pdf_name) as pdf:
432
- pdf.savefig(spec_plot(mixture=mixture_audio,
433
- feature=spectrogram,
434
- predict=predict,
435
- labels=labels,
436
- title=title))
466
+ pdf.savefig(
467
+ spec_plot(
468
+ mixture=mixture_audio,
469
+ feature=spectrogram,
470
+ predict=predict,
471
+ labels=labels,
472
+ title=title,
473
+ )
474
+ )
437
475
  for index in indices:
438
476
  if energy:
439
477
  t_tmp = None
@@ -444,10 +482,14 @@ def main() -> None:
444
482
  if predict is not None:
445
483
  p_tmp = 10 * np.log10(predict + np.finfo(np.float32).eps)
446
484
 
447
- pdf.savefig(spec_energy_plot(mixture=mixture_audio,
448
- feature=spectrogram,
449
- truth_f=t_tmp,
450
- predict=p_tmp))
485
+ pdf.savefig(
486
+ spec_energy_plot(
487
+ mixture=mixture_audio,
488
+ feature=spectrogram,
489
+ truth_f=t_tmp,
490
+ predict=p_tmp,
491
+ )
492
+ )
451
493
  else:
452
494
  p_tmp = None
453
495
  if predict is not None:
@@ -457,18 +499,22 @@ def main() -> None:
457
499
  if labels is not None:
458
500
  l_tmp = labels[index]
459
501
 
460
- pdf.savefig(class_plot(mixture=mixture_audio,
461
- target=target_audio[index],
462
- truth_f=truth_f[:, index],
463
- predict=p_tmp,
464
- label=l_tmp))
465
- logger.info(f'Wrote {pdf_name}')
502
+ pdf.savefig(
503
+ class_plot(
504
+ mixture=mixture_audio,
505
+ target=target_audio[index],
506
+ truth_f=truth_f[:, index],
507
+ predict=p_tmp,
508
+ label=l_tmp,
509
+ )
510
+ )
511
+ logger.info(f"Wrote {pdf_name}")
466
512
 
467
513
  if output_name:
468
- with h5py.File(output_name, 'w') as f:
469
- f.create_dataset(name='predict', data=predict)
470
- logger.info(f'Wrote {output_name}')
514
+ with h5py.File(output_name, "w") as f:
515
+ f.create_dataset(name="predict", data=predict)
516
+ logger.info(f"Wrote {output_name}")
471
517
 
472
518
 
473
- if __name__ == '__main__':
519
+ if __name__ == "__main__":
474
520
  main()