sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py CHANGED
@@ -76,8 +76,8 @@ generation functions. By default, these are included with the feature data in a
76
76
  truth generation is turned on with default settings (see truth section) and a single class, i.e., detecting a single
77
77
  type of sound. The truth format is a single float per class representing the probability of activity/presence, and
78
78
  multi-class truth is possible by specifying the number of classes and either a scalar index or a vector of indices in
79
- which to put the truth result. For example, 'num_class: 3' and 'truth_index: 2' adds a 1x3 vector to the feature data
80
- with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
79
+ which to put the truth result. For example, 'num_class: 3' and 'class_indices: [ 2 ]' adds a 1x3 vector to the feature
80
+ data with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
81
81
 
82
82
  The mixture is created with potential data augmentation functions in the following way:
83
83
  1. apply noise augmentation rule
@@ -112,6 +112,7 @@ targets:
112
112
  will find all .wav files in the specified directories and process them as targets.
113
113
 
114
114
  """
115
+
115
116
  import signal
116
117
  from dataclasses import dataclass
117
118
 
@@ -124,7 +125,7 @@ def signal_handler(_sig, _frame):
124
125
 
125
126
  from sonusai import logger
126
127
 
127
- logger.info('Canceled due to keyboard interrupt')
128
+ logger.info("Canceled due to keyboard interrupt")
128
129
  sys.exit(1)
129
130
 
130
131
 
@@ -133,34 +134,34 @@ signal.signal(signal.SIGINT, signal_handler)
133
134
 
134
135
  @dataclass
135
136
  class MPGlobal:
136
- mixdb: MixtureDatabase = None
137
- save_mix: bool = None
138
- save_ft: bool = None
139
- save_segsnr: bool = None
137
+ mixdb: MixtureDatabase
138
+ save_mix: bool
139
+ save_ft: bool
140
+ save_segsnr: bool
140
141
 
141
142
 
142
- MP_GLOBAL = MPGlobal()
143
+ MP_GLOBAL: MPGlobal
143
144
 
144
145
 
145
- def genmixdb(location: str,
146
- save_mix: bool = False,
147
- save_ft: bool = False,
148
- save_segsnr: bool = False,
149
- logging: bool = True,
150
- show_progress: bool = False,
151
- test: bool = False,
152
- save_json: bool = False) -> MixtureDatabase:
146
+ def genmixdb(
147
+ location: str,
148
+ save_mix: bool = False,
149
+ save_ft: bool = False,
150
+ save_segsnr: bool = False,
151
+ logging: bool = True,
152
+ show_progress: bool = False,
153
+ test: bool = False,
154
+ save_json: bool = False,
155
+ ) -> MixtureDatabase:
153
156
  from random import seed
154
157
 
155
158
  import yaml
156
- from tqdm import tqdm
157
159
 
158
- from sonusai import SonusAIError
159
160
  from sonusai import logger
160
- from sonusai.mixture import AugmentationRule
161
- from sonusai.mixture import MixtureDatabase
162
161
  from sonusai.mixture import SAMPLE_BYTES
163
162
  from sonusai.mixture import SAMPLE_RATE
163
+ from sonusai.mixture import AugmentationRule
164
+ from sonusai.mixture import MixtureDatabase
164
165
  from sonusai.mixture import balance_targets
165
166
  from sonusai.mixture import generate_mixtures
166
167
  from sonusai.mixture import get_all_snrs_from_config
@@ -182,11 +183,13 @@ def genmixdb(location: str,
182
183
  from sonusai.mixture import populate_spectral_mask_table
183
184
  from sonusai.mixture import populate_target_file_table
184
185
  from sonusai.mixture import populate_top_table
186
+ from sonusai.mixture import populate_truth_parameters_table
185
187
  from sonusai.mixture import update_mixid_width
186
188
  from sonusai.utils import dataclass_from_dict
187
189
  from sonusai.utils import human_readable_size
188
- from sonusai.utils import pp_tqdm_imap
190
+ from sonusai.utils import par_track
189
191
  from sonusai.utils import seconds_to_hms
192
+ from sonusai.utils import track
190
193
 
191
194
  config = load_config(location)
192
195
  initialize_db(location=location, test=test)
@@ -197,113 +200,116 @@ def genmixdb(location: str,
197
200
  populate_class_label_table(location, config, test)
198
201
  populate_class_weights_threshold_table(location, config, test)
199
202
  populate_spectral_mask_table(location, config, test)
203
+ populate_truth_parameters_table(location, config, test)
200
204
 
201
- seed(config['seed'])
205
+ seed(config["seed"])
202
206
 
203
207
  if logging:
204
- logger.debug(f'Seed: {config["seed"]}')
205
- logger.debug('Configuration:')
208
+ logger.debug(f"Seed: {config['seed']}")
209
+ logger.debug("Configuration:")
206
210
  logger.debug(yaml.dump(config))
207
211
 
208
212
  if logging:
209
- logger.info('Collecting targets')
213
+ logger.info("Collecting targets")
210
214
 
211
215
  target_files = get_target_files(config, show_progress=show_progress)
212
216
 
213
217
  if len(target_files) == 0:
214
- raise SonusAIError('Canceled due to no targets')
218
+ raise RuntimeError("Canceled due to no targets")
215
219
 
216
220
  populate_target_file_table(location, target_files, test)
217
221
 
218
222
  if logging:
219
- logger.debug('List of targets:')
223
+ logger.debug("List of targets:")
220
224
  logger.debug(yaml.dump([target.name for target in mixdb.target_files], default_flow_style=False))
221
- logger.debug('')
225
+ logger.debug("")
222
226
 
223
227
  if logging:
224
- logger.info('Collecting noises')
228
+ logger.info("Collecting noises")
225
229
 
226
230
  noise_files = get_noise_files(config, show_progress=show_progress)
227
231
 
228
232
  populate_noise_file_table(location, noise_files, test)
229
233
 
230
234
  if logging:
231
- logger.debug('List of noises:')
235
+ logger.debug("List of noises:")
232
236
  logger.debug(yaml.dump([noise.name for noise in mixdb.noise_files], default_flow_style=False))
233
- logger.debug('')
237
+ logger.debug("")
234
238
 
235
239
  if logging:
236
- logger.info('Collecting impulse responses')
240
+ logger.info("Collecting impulse responses")
237
241
 
238
242
  impulse_response_files = get_impulse_response_files(config)
239
243
 
240
244
  populate_impulse_response_file_table(location, impulse_response_files, test)
241
245
 
242
246
  if logging:
243
- logger.debug('List of impulse responses:')
247
+ logger.debug("List of impulse responses:")
244
248
  logger.debug(
245
- yaml.dump([impulse_response for impulse_response in mixdb.impulse_response_files],
246
- default_flow_style=False))
247
- logger.debug('')
249
+ yaml.dump(
250
+ [entry.file for entry in mixdb.impulse_response_files],
251
+ default_flow_style=False,
252
+ )
253
+ )
254
+ logger.debug("")
248
255
 
249
256
  if logging:
250
- logger.info('Collecting target augmentations')
257
+ logger.info("Collecting target augmentations")
251
258
 
252
- target_augmentations = get_augmentation_rules(rules=config['target_augmentations'],
253
- num_ir=mixdb.num_impulse_response_files)
259
+ target_augmentations = get_augmentation_rules(
260
+ rules=config["target_augmentations"], num_ir=mixdb.num_impulse_response_files
261
+ )
254
262
  mixups = get_mixups(target_augmentations)
255
263
 
256
264
  if logging:
257
265
  for mixup in mixups:
258
- logger.debug(f'Expanded list of target augmentation rules for mixup of {mixup}:')
266
+ logger.debug(f"Expanded list of target augmentation rules for mixup of {mixup}:")
259
267
  for target_augmentation in get_target_augmentations_for_mixup(target_augmentations, mixup):
260
268
  ta_dict = target_augmentation.to_dict()
261
- del ta_dict['mixup']
262
- logger.debug(f'- {ta_dict}')
263
- logger.debug('')
269
+ del ta_dict["mixup"]
270
+ logger.debug(f"- {ta_dict}")
271
+ logger.debug("")
264
272
 
265
273
  if logging:
266
- logger.info('Collecting noise augmentations')
274
+ logger.info("Collecting noise augmentations")
267
275
 
268
- noise_augmentations = get_augmentation_rules(rules=config['noise_augmentations'],
269
- num_ir=mixdb.num_impulse_response_files)
276
+ noise_augmentations = get_augmentation_rules(
277
+ rules=config["noise_augmentations"], num_ir=mixdb.num_impulse_response_files
278
+ )
270
279
 
271
280
  if logging:
272
- logger.debug('Expanded list of noise augmentations:')
281
+ logger.debug("Expanded list of noise augmentations:")
273
282
  for noise_augmentation in noise_augmentations:
274
283
  na_dict = noise_augmentation.to_dict()
275
- del na_dict['mixup']
276
- logger.debug(f'- {na_dict}')
277
- logger.debug('')
284
+ del na_dict["mixup"]
285
+ logger.debug(f"- {na_dict}")
286
+ logger.debug("")
278
287
 
279
288
  if logging:
280
- logger.debug(f'SNRs: {config["snrs"]}\n')
281
- logger.debug(f'Random SNRs: {config["random_snrs"]}\n')
282
- logger.debug(f'Noise mix mode: {mixdb.noise_mix_mode}\n')
283
- logger.debug(f'Spectral masks:')
289
+ logger.debug(f"SNRs: {config['snrs']}\n")
290
+ logger.debug(f"Random SNRs: {config['random_snrs']}\n")
291
+ logger.debug(f"Noise mix mode: {mixdb.noise_mix_mode}\n")
292
+ logger.debug("Spectral masks:")
284
293
  for spectral_mask in mixdb.spectral_masks:
285
- logger.debug(f'- {spectral_mask}')
286
- logger.debug('')
287
-
288
- if mixdb.truth_mutex and any(mixup > 1 for mixup in mixups):
289
- raise SonusAIError(f'Mutex truth mode is not compatible with mixup')
294
+ logger.debug(f"- {spectral_mask}")
295
+ logger.debug("")
290
296
 
291
297
  if logging:
292
- logger.info('Collecting augmented targets')
298
+ logger.info("Collecting augmented targets")
293
299
 
294
300
  augmented_targets = get_augmented_targets(target_files, target_augmentations, mixups)
295
301
 
296
- if config['class_balancing']:
297
- class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config['class_balancing_augmentation'])
302
+ if config["class_balancing"]:
303
+ class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config["class_balancing_augmentation"])
298
304
  augmented_targets, target_augmentations = balance_targets(
299
305
  augmented_targets=augmented_targets,
300
306
  targets=target_files,
301
307
  target_augmentations=target_augmentations,
302
308
  class_balancing_augmentation=class_balancing_augmentation,
303
309
  num_classes=mixdb.num_classes,
304
- truth_mutex=mixdb.truth_mutex,
305
310
  num_ir=mixdb.num_impulse_response_files,
306
- mixups=mixups)
311
+ mixups=mixups,
312
+ )
307
313
 
308
314
  target_audio_samples = sum([targets.samples for targets in mixdb.target_files])
309
315
  target_audio_duration = target_audio_samples / SAMPLE_RATE
@@ -311,13 +317,17 @@ def genmixdb(location: str,
311
317
  noise_audio_samples = noise_audio_duration * SAMPLE_RATE
312
318
 
313
319
  if logging:
314
- logger.info('')
315
- logger.info(f'Target audio: {mixdb.num_target_files} files, '
316
- f'{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, '
317
- f'{seconds_to_hms(seconds=target_audio_duration)}')
318
- logger.info(f'Noise audio: {mixdb.num_noise_files} files, '
319
- f'{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, '
320
- f'{seconds_to_hms(seconds=noise_audio_duration)}')
320
+ logger.info("")
321
+ logger.info(
322
+ f"Target audio: {mixdb.num_target_files} files, "
323
+ f"{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, "
324
+ f"{seconds_to_hms(seconds=target_audio_duration)}"
325
+ )
326
+ logger.info(
327
+ f"Noise audio: {mixdb.num_noise_files} files, "
328
+ f"{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, "
329
+ f"{seconds_to_hms(seconds=noise_audio_duration)}"
330
+ )
321
331
 
322
332
  used_noise_files, used_noise_samples, mixtures = generate_mixtures(
323
333
  noise_mix_mode=mixdb.noise_mix_mode,
@@ -330,41 +340,48 @@ def genmixdb(location: str,
330
340
  all_snrs=get_all_snrs_from_config(config),
331
341
  mixups=mixups,
332
342
  num_classes=mixdb.num_classes,
333
- truth_mutex=mixdb.truth_mutex,
334
343
  feature_step_samples=mixdb.feature_step_samples,
335
- num_ir=mixdb.num_impulse_response_files)
344
+ num_ir=mixdb.num_impulse_response_files,
345
+ )
336
346
 
337
347
  num_mixtures = len(mixtures)
338
348
  update_mixid_width(location, num_mixtures, test)
339
349
 
340
350
  if logging:
341
- logger.info('')
342
- logger.info(f'Found {num_mixtures:,} mixtures to process')
351
+ logger.info("")
352
+ logger.info(f"Found {num_mixtures:,} mixtures to process")
343
353
 
344
354
  total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
345
355
 
346
356
  if logging:
347
- log_duration_and_sizes(total_duration=total_duration,
348
- num_classes=mixdb.num_classes,
349
- feature_step_samples=mixdb.feature_step_samples,
350
- feature_parameters=mixdb.feature_parameters,
351
- stride=mixdb.fg_stride,
352
- desc='Estimated')
353
- logger.info(f'Feature shape: '
354
- f'{mixdb.fg_stride} x {mixdb.feature_parameters} '
355
- f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
356
- logger.info(f'Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
357
- logger.info(f'Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)')
358
- logger.info('')
357
+ log_duration_and_sizes(
358
+ total_duration=total_duration,
359
+ num_classes=mixdb.num_classes,
360
+ feature_step_samples=mixdb.feature_step_samples,
361
+ feature_parameters=mixdb.feature_parameters,
362
+ stride=mixdb.fg_stride,
363
+ desc="Estimated",
364
+ )
365
+ logger.info(
366
+ f"Feature shape: "
367
+ f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
368
+ f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
369
+ )
370
+ logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
371
+ logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
372
+ logger.info("")
359
373
 
360
374
  # Fill in the details
361
375
  if logging:
362
- logger.info('Generating mixtures')
363
- progress = tqdm(total=num_mixtures, disable=not show_progress)
364
- mixtures = pp_tqdm_imap(_process_mixture, mixtures,
365
- progress=progress,
366
- initializer=_initializer,
367
- initargs=(location, save_mix, save_ft, save_segsnr, test))
376
+ logger.info("Generating mixtures")
377
+ progress = track(total=num_mixtures, disable=not show_progress)
378
+ mixtures = par_track(
379
+ _process_mixture,
380
+ mixtures,
381
+ progress=progress,
382
+ initializer=_initializer,
383
+ initargs=(location, save_mix, save_ft, save_segsnr, test),
384
+ )
368
385
  progress.close()
369
386
 
370
387
  populate_mixture_table(location, mixtures, test)
@@ -378,20 +395,22 @@ def genmixdb(location: str,
378
395
  noise_samples_percent = (float(used_noise_samples) / float(noise_audio_samples)) * 100
379
396
 
380
397
  if logging:
381
- log_duration_and_sizes(total_duration=total_duration,
382
- num_classes=mixdb.num_classes,
383
- feature_step_samples=mixdb.feature_step_samples,
384
- feature_parameters=mixdb.feature_parameters,
385
- stride=mixdb.fg_stride,
386
- desc='Actual')
387
- logger.info('')
388
- logger.info(f'Used {noise_files_percent:,.0f}% of noise files')
389
- logger.info(f'Used {noise_samples_percent:,.0f}% of noise audio')
390
- logger.info('')
398
+ log_duration_and_sizes(
399
+ total_duration=total_duration,
400
+ num_classes=mixdb.num_classes,
401
+ feature_step_samples=mixdb.feature_step_samples,
402
+ feature_parameters=mixdb.feature_parameters,
403
+ stride=mixdb.fg_stride,
404
+ desc="Actual",
405
+ )
406
+ logger.info("")
407
+ logger.info(f"Used {noise_files_percent:,.0f}% of noise files")
408
+ logger.info(f"Used {noise_samples_percent:,.0f}% of noise audio")
409
+ logger.info("")
391
410
 
392
411
  if not test and save_json:
393
412
  if logging:
394
- logger.info(f'Writing JSON version of database to {location}')
413
+ logger.info(f"Writing JSON version of database to {location}")
395
414
  mixdb = MixtureDatabase(location)
396
415
  mixdb.save()
397
416
 
@@ -399,10 +418,14 @@ def genmixdb(location: str,
399
418
 
400
419
 
401
420
  def _initializer(location: str, save_mix: bool, save_ft: bool, save_segsnr: bool, test: bool) -> None:
402
- MP_GLOBAL.mixdb = MixtureDatabase(location, test)
403
- MP_GLOBAL.save_mix = save_mix
404
- MP_GLOBAL.save_ft = save_ft
405
- MP_GLOBAL.save_segsnr = save_segsnr
421
+ global MP_GLOBAL
422
+
423
+ MP_GLOBAL = MPGlobal(
424
+ mixdb=MixtureDatabase(location, test),
425
+ save_mix=save_mix,
426
+ save_ft=save_ft,
427
+ save_segsnr=save_segsnr,
428
+ )
406
429
 
407
430
 
408
431
  def _process_mixture(mixture: Mixture) -> Mixture:
@@ -410,11 +433,13 @@ def _process_mixture(mixture: Mixture) -> Mixture:
410
433
 
411
434
  from sonusai.mixture import get_ft
412
435
  from sonusai.mixture import get_segsnr
413
- from sonusai.mixture import get_truth_t
436
+ from sonusai.mixture import get_truth
414
437
  from sonusai.mixture import update_mixture
415
- from sonusai.mixture import write_mixture_data
438
+ from sonusai.mixture import write_cached_data
416
439
  from sonusai.mixture import write_mixture_metadata
417
440
 
441
+ global MP_GLOBAL
442
+
418
443
  with_data = MP_GLOBAL.save_mix or MP_GLOBAL.save_ft
419
444
  mixdb = MP_GLOBAL.mixdb
420
445
 
@@ -424,31 +449,41 @@ def _process_mixture(mixture: Mixture) -> Mixture:
424
449
  write_data: list[tuple[str, Any]] = []
425
450
 
426
451
  if MP_GLOBAL.save_mix:
427
- write_data.append(('targets', genmix_data.targets))
428
- write_data.append(('noise', genmix_data.noise))
429
- write_data.append(('mixture', genmix_data.mixture))
452
+ write_data.append(("targets", genmix_data.targets))
453
+ write_data.append(("noise", genmix_data.noise))
454
+ write_data.append(("mixture", genmix_data.mixture))
430
455
 
431
456
  if MP_GLOBAL.save_ft:
432
- truth_t = get_truth_t(mixdb=mixdb,
433
- mixture=mixture,
434
- targets_audio=genmix_data.targets,
435
- noise_audio=genmix_data.noise,
436
- mixture_audio=genmix_data.mixture)
437
- feature, truth_f = get_ft(mixdb=mixdb,
438
- mixture=mixture,
439
- mixture_audio=genmix_data.mixture,
440
- truth_t=truth_t)
441
- write_data.append(('feature', feature))
442
- write_data.append(('truth_f', truth_f))
457
+ if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
458
+ raise RuntimeError("Mixture data was not generated properly")
459
+ truth_t = get_truth(
460
+ mixdb=mixdb,
461
+ mixture=mixture,
462
+ targets_audio=genmix_data.targets,
463
+ noise_audio=genmix_data.noise,
464
+ mixture_audio=genmix_data.mixture,
465
+ )
466
+ feature, truth_f = get_ft(
467
+ mixdb=mixdb,
468
+ mixture=mixture,
469
+ mixture_audio=genmix_data.mixture,
470
+ truth_t=truth_t,
471
+ )
472
+ write_data.append(("feature", feature))
473
+ write_data.append(("truth_f", truth_f))
443
474
 
444
475
  if MP_GLOBAL.save_segsnr:
445
- segsnr = get_segsnr(mixdb=mixdb,
446
- mixture=mixture,
447
- target_audio=genmix_data.target,
448
- noise=genmix_data.noise)
449
- write_data.append(('segsnr', segsnr))
450
-
451
- write_mixture_data(mixdb, mixture, write_data)
476
+ if genmix_data.target is None:
477
+ raise RuntimeError("Target data was not generated properly")
478
+ segsnr = get_segsnr(
479
+ mixdb=mixdb,
480
+ mixture=mixture,
481
+ target_audio=genmix_data.target,
482
+ noise=genmix_data.noise,
483
+ )
484
+ write_data.append(("segsnr", segsnr))
485
+
486
+ write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
452
487
  write_mixture_metadata(mixdb, mixture)
453
488
 
454
489
  return mixture
@@ -478,13 +513,13 @@ def main() -> None:
478
513
  from sonusai.mixture import load_config
479
514
  from sonusai.utils import seconds_to_hms
480
515
 
481
- verbose = args['--verbose']
482
- save_mix = args['--mix']
483
- save_ft = args['--ft']
484
- save_segsnr = args['--segsnr']
485
- dryrun = args['--dryrun']
486
- save_json = args['--json']
487
- location = args['LOC']
516
+ verbose = args["--verbose"]
517
+ save_mix = args["--mix"]
518
+ save_ft = args["--ft"]
519
+ save_segsnr = args["--segsnr"]
520
+ dryrun = args["--dryrun"]
521
+ save_json = args["--json"]
522
+ location = args["LOC"]
488
523
 
489
524
  start_time = time.monotonic()
490
525
 
@@ -493,30 +528,36 @@ def main() -> None:
493
528
 
494
529
  makedirs(location, exist_ok=True)
495
530
 
496
- create_file_handler(join(location, 'genmixdb.log'))
531
+ create_file_handler(join(location, "genmixdb.log"))
497
532
  update_console_handler(verbose)
498
- initial_log_messages('genmixdb')
533
+ initial_log_messages("genmixdb")
499
534
 
500
535
  if dryrun:
501
536
  config = load_config(location)
502
- logger.info('Dryrun configuration:')
537
+ logger.info("Dryrun configuration:")
503
538
  logger.info(yaml.dump(config))
504
539
  return
505
540
 
506
- logger.info(f'Creating mixture database for {location}')
507
- logger.info('')
508
-
509
- genmixdb(location=location,
510
- save_mix=save_mix,
511
- save_ft=save_ft,
512
- save_segsnr=save_segsnr,
513
- show_progress=True,
514
- save_json=save_json)
541
+ logger.info(f"Creating mixture database for {location}")
542
+ logger.info("")
543
+
544
+ try:
545
+ genmixdb(
546
+ location=location,
547
+ save_mix=save_mix,
548
+ save_ft=save_ft,
549
+ save_segsnr=save_segsnr,
550
+ show_progress=True,
551
+ save_json=save_json,
552
+ )
553
+ except Exception as e:
554
+ logger.debug(e)
555
+ raise
515
556
 
516
557
  end_time = time.monotonic()
517
- logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
518
- logger.info('')
558
+ logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
559
+ logger.info("")
519
560
 
520
561
 
521
- if __name__ == '__main__':
562
+ if __name__ == "__main__":
522
563
  main()