sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py CHANGED
@@ -76,8 +76,8 @@ generation functions. By default, these are included with the feature data in a
76
76
  truth generation is turned on with default settings (see truth section) and a single class, i.e., detecting a single
77
77
  type of sound. The truth format is a single float per class representing the probability of activity/presence, and
78
78
  multi-class truth is possible by specifying the number of classes and either a scalar index or a vector of indices in
79
- which to put the truth result. For example, 'num_class: 3' and 'truth_index: 2' adds a 1x3 vector to the feature data
80
- with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
79
+ which to put the truth result. For example, 'num_class: 3' and 'class_indices: [ 2 ]' adds a 1x3 vector to the feature
80
+ data with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
81
81
 
82
82
  The mixture is created with potential data augmentation functions in the following way:
83
83
  1. apply noise augmentation rule
@@ -112,11 +112,10 @@ targets:
112
112
  will find all .wav files in the specified directories and process them as targets.
113
113
 
114
114
  """
115
+
115
116
  import signal
116
- from dataclasses import dataclass
117
117
 
118
118
  from sonusai.mixture import Mixture
119
- from sonusai.mixture import MixtureDatabase
120
119
 
121
120
 
122
121
  def signal_handler(_sig, _frame):
@@ -124,43 +123,33 @@ def signal_handler(_sig, _frame):
124
123
 
125
124
  from sonusai import logger
126
125
 
127
- logger.info('Canceled due to keyboard interrupt')
126
+ logger.info("Canceled due to keyboard interrupt")
128
127
  sys.exit(1)
129
128
 
130
129
 
131
130
  signal.signal(signal.SIGINT, signal_handler)
132
131
 
133
132
 
134
- @dataclass
135
- class MPGlobal:
136
- mixdb: MixtureDatabase = None
137
- save_mix: bool = None
138
- save_ft: bool = None
139
- save_segsnr: bool = None
140
-
141
-
142
- MP_GLOBAL = MPGlobal()
143
-
144
-
145
- def genmixdb(location: str,
146
- save_mix: bool = False,
147
- save_ft: bool = False,
148
- save_segsnr: bool = False,
149
- logging: bool = True,
150
- show_progress: bool = False,
151
- test: bool = False,
152
- save_json: bool = False) -> MixtureDatabase:
133
+ def genmixdb(
134
+ location: str,
135
+ save_mix: bool = False,
136
+ save_ft: bool = False,
137
+ save_segsnr: bool = False,
138
+ logging: bool = True,
139
+ show_progress: bool = False,
140
+ test: bool = False,
141
+ save_json: bool = False,
142
+ ) -> None:
143
+ from functools import partial
153
144
  from random import seed
154
145
 
155
146
  import yaml
156
- from tqdm import tqdm
157
147
 
158
- from sonusai import SonusAIError
159
148
  from sonusai import logger
160
- from sonusai.mixture import AugmentationRule
161
- from sonusai.mixture import MixtureDatabase
162
149
  from sonusai.mixture import SAMPLE_BYTES
163
150
  from sonusai.mixture import SAMPLE_RATE
151
+ from sonusai.mixture import AugmentationRule
152
+ from sonusai.mixture import MixtureDatabase
164
153
  from sonusai.mixture import balance_targets
165
154
  from sonusai.mixture import generate_mixtures
166
155
  from sonusai.mixture import get_all_snrs_from_config
@@ -182,11 +171,13 @@ def genmixdb(location: str,
182
171
  from sonusai.mixture import populate_spectral_mask_table
183
172
  from sonusai.mixture import populate_target_file_table
184
173
  from sonusai.mixture import populate_top_table
174
+ from sonusai.mixture import populate_truth_parameters_table
185
175
  from sonusai.mixture import update_mixid_width
186
176
  from sonusai.utils import dataclass_from_dict
187
177
  from sonusai.utils import human_readable_size
188
- from sonusai.utils import pp_tqdm_imap
178
+ from sonusai.utils import par_track
189
179
  from sonusai.utils import seconds_to_hms
180
+ from sonusai.utils import track
190
181
 
191
182
  config = load_config(location)
192
183
  initialize_db(location=location, test=test)
@@ -197,113 +188,116 @@ def genmixdb(location: str,
197
188
  populate_class_label_table(location, config, test)
198
189
  populate_class_weights_threshold_table(location, config, test)
199
190
  populate_spectral_mask_table(location, config, test)
191
+ populate_truth_parameters_table(location, config, test)
200
192
 
201
- seed(config['seed'])
193
+ seed(config["seed"])
202
194
 
203
195
  if logging:
204
- logger.debug(f'Seed: {config["seed"]}')
205
- logger.debug('Configuration:')
196
+ logger.debug(f"Seed: {config['seed']}")
197
+ logger.debug("Configuration:")
206
198
  logger.debug(yaml.dump(config))
207
199
 
208
200
  if logging:
209
- logger.info('Collecting targets')
201
+ logger.info("Collecting targets")
210
202
 
211
203
  target_files = get_target_files(config, show_progress=show_progress)
212
204
 
213
205
  if len(target_files) == 0:
214
- raise SonusAIError('Canceled due to no targets')
206
+ raise RuntimeError("Canceled due to no targets")
215
207
 
216
208
  populate_target_file_table(location, target_files, test)
217
209
 
218
210
  if logging:
219
- logger.debug('List of targets:')
211
+ logger.debug("List of targets:")
220
212
  logger.debug(yaml.dump([target.name for target in mixdb.target_files], default_flow_style=False))
221
- logger.debug('')
213
+ logger.debug("")
222
214
 
223
215
  if logging:
224
- logger.info('Collecting noises')
216
+ logger.info("Collecting noises")
225
217
 
226
218
  noise_files = get_noise_files(config, show_progress=show_progress)
227
219
 
228
220
  populate_noise_file_table(location, noise_files, test)
229
221
 
230
222
  if logging:
231
- logger.debug('List of noises:')
223
+ logger.debug("List of noises:")
232
224
  logger.debug(yaml.dump([noise.name for noise in mixdb.noise_files], default_flow_style=False))
233
- logger.debug('')
225
+ logger.debug("")
234
226
 
235
227
  if logging:
236
- logger.info('Collecting impulse responses')
228
+ logger.info("Collecting impulse responses")
237
229
 
238
230
  impulse_response_files = get_impulse_response_files(config)
239
231
 
240
232
  populate_impulse_response_file_table(location, impulse_response_files, test)
241
233
 
242
234
  if logging:
243
- logger.debug('List of impulse responses:')
235
+ logger.debug("List of impulse responses:")
244
236
  logger.debug(
245
- yaml.dump([impulse_response for impulse_response in mixdb.impulse_response_files],
246
- default_flow_style=False))
247
- logger.debug('')
237
+ yaml.dump(
238
+ [entry.file for entry in mixdb.impulse_response_files],
239
+ default_flow_style=False,
240
+ )
241
+ )
242
+ logger.debug("")
248
243
 
249
244
  if logging:
250
- logger.info('Collecting target augmentations')
245
+ logger.info("Collecting target augmentations")
251
246
 
252
- target_augmentations = get_augmentation_rules(rules=config['target_augmentations'],
253
- num_ir=mixdb.num_impulse_response_files)
247
+ target_augmentations = get_augmentation_rules(
248
+ rules=config["target_augmentations"], num_ir=mixdb.num_impulse_response_files
249
+ )
254
250
  mixups = get_mixups(target_augmentations)
255
251
 
256
252
  if logging:
257
253
  for mixup in mixups:
258
- logger.debug(f'Expanded list of target augmentation rules for mixup of {mixup}:')
254
+ logger.debug(f"Expanded list of target augmentation rules for mixup of {mixup}:")
259
255
  for target_augmentation in get_target_augmentations_for_mixup(target_augmentations, mixup):
260
256
  ta_dict = target_augmentation.to_dict()
261
- del ta_dict['mixup']
262
- logger.debug(f'- {ta_dict}')
263
- logger.debug('')
257
+ del ta_dict["mixup"]
258
+ logger.debug(f"- {ta_dict}")
259
+ logger.debug("")
264
260
 
265
261
  if logging:
266
- logger.info('Collecting noise augmentations')
262
+ logger.info("Collecting noise augmentations")
267
263
 
268
- noise_augmentations = get_augmentation_rules(rules=config['noise_augmentations'],
269
- num_ir=mixdb.num_impulse_response_files)
264
+ noise_augmentations = get_augmentation_rules(
265
+ rules=config["noise_augmentations"], num_ir=mixdb.num_impulse_response_files
266
+ )
270
267
 
271
268
  if logging:
272
- logger.debug('Expanded list of noise augmentations:')
269
+ logger.debug("Expanded list of noise augmentations:")
273
270
  for noise_augmentation in noise_augmentations:
274
271
  na_dict = noise_augmentation.to_dict()
275
- del na_dict['mixup']
276
- logger.debug(f'- {na_dict}')
277
- logger.debug('')
272
+ del na_dict["mixup"]
273
+ logger.debug(f"- {na_dict}")
274
+ logger.debug("")
278
275
 
279
276
  if logging:
280
- logger.debug(f'SNRs: {config["snrs"]}\n')
281
- logger.debug(f'Random SNRs: {config["random_snrs"]}\n')
282
- logger.debug(f'Noise mix mode: {mixdb.noise_mix_mode}\n')
283
- logger.debug(f'Spectral masks:')
277
+ logger.debug(f"SNRs: {config['snrs']}\n")
278
+ logger.debug(f"Random SNRs: {config['random_snrs']}\n")
279
+ logger.debug(f"Noise mix mode: {mixdb.noise_mix_mode}\n")
280
+ logger.debug("Spectral masks:")
284
281
  for spectral_mask in mixdb.spectral_masks:
285
- logger.debug(f'- {spectral_mask}')
286
- logger.debug('')
287
-
288
- if mixdb.truth_mutex and any(mixup > 1 for mixup in mixups):
289
- raise SonusAIError(f'Mutex truth mode is not compatible with mixup')
282
+ logger.debug(f"- {spectral_mask}")
283
+ logger.debug("")
290
284
 
291
285
  if logging:
292
- logger.info('Collecting augmented targets')
286
+ logger.info("Collecting augmented targets")
293
287
 
294
288
  augmented_targets = get_augmented_targets(target_files, target_augmentations, mixups)
295
289
 
296
- if config['class_balancing']:
297
- class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config['class_balancing_augmentation'])
290
+ if config["class_balancing"]:
291
+ class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config["class_balancing_augmentation"])
298
292
  augmented_targets, target_augmentations = balance_targets(
299
293
  augmented_targets=augmented_targets,
300
294
  targets=target_files,
301
295
  target_augmentations=target_augmentations,
302
296
  class_balancing_augmentation=class_balancing_augmentation,
303
297
  num_classes=mixdb.num_classes,
304
- truth_mutex=mixdb.truth_mutex,
305
298
  num_ir=mixdb.num_impulse_response_files,
306
- mixups=mixups)
299
+ mixups=mixups,
300
+ )
307
301
 
308
302
  target_audio_samples = sum([targets.samples for targets in mixdb.target_files])
309
303
  target_audio_duration = target_audio_samples / SAMPLE_RATE
@@ -311,13 +305,17 @@ def genmixdb(location: str,
311
305
  noise_audio_samples = noise_audio_duration * SAMPLE_RATE
312
306
 
313
307
  if logging:
314
- logger.info('')
315
- logger.info(f'Target audio: {mixdb.num_target_files} files, '
316
- f'{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, '
317
- f'{seconds_to_hms(seconds=target_audio_duration)}')
318
- logger.info(f'Noise audio: {mixdb.num_noise_files} files, '
319
- f'{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, '
320
- f'{seconds_to_hms(seconds=noise_audio_duration)}')
308
+ logger.info("")
309
+ logger.info(
310
+ f"Target audio: {mixdb.num_target_files} files, "
311
+ f"{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, "
312
+ f"{seconds_to_hms(seconds=target_audio_duration)}"
313
+ )
314
+ logger.info(
315
+ f"Noise audio: {mixdb.num_noise_files} files, "
316
+ f"{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, "
317
+ f"{seconds_to_hms(seconds=noise_audio_duration)}"
318
+ )
321
319
 
322
320
  used_noise_files, used_noise_samples, mixtures = generate_mixtures(
323
321
  noise_mix_mode=mixdb.noise_mix_mode,
@@ -330,41 +328,53 @@ def genmixdb(location: str,
330
328
  all_snrs=get_all_snrs_from_config(config),
331
329
  mixups=mixups,
332
330
  num_classes=mixdb.num_classes,
333
- truth_mutex=mixdb.truth_mutex,
334
331
  feature_step_samples=mixdb.feature_step_samples,
335
- num_ir=mixdb.num_impulse_response_files)
332
+ num_ir=mixdb.num_impulse_response_files,
333
+ )
336
334
 
337
335
  num_mixtures = len(mixtures)
338
336
  update_mixid_width(location, num_mixtures, test)
339
337
 
340
338
  if logging:
341
- logger.info('')
342
- logger.info(f'Found {num_mixtures:,} mixtures to process')
339
+ logger.info("")
340
+ logger.info(f"Found {num_mixtures:,} mixtures to process")
343
341
 
344
342
  total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
345
343
 
346
344
  if logging:
347
- log_duration_and_sizes(total_duration=total_duration,
348
- num_classes=mixdb.num_classes,
349
- feature_step_samples=mixdb.feature_step_samples,
350
- feature_parameters=mixdb.feature_parameters,
351
- stride=mixdb.fg_stride,
352
- desc='Estimated')
353
- logger.info(f'Feature shape: '
354
- f'{mixdb.fg_stride} x {mixdb.feature_parameters} '
355
- f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
356
- logger.info(f'Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
357
- logger.info(f'Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)')
358
- logger.info('')
345
+ log_duration_and_sizes(
346
+ total_duration=total_duration,
347
+ num_classes=mixdb.num_classes,
348
+ feature_step_samples=mixdb.feature_step_samples,
349
+ feature_parameters=mixdb.feature_parameters,
350
+ stride=mixdb.fg_stride,
351
+ desc="Estimated",
352
+ )
353
+ logger.info(
354
+ f"Feature shape: "
355
+ f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
356
+ f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
357
+ )
358
+ logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
359
+ logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
360
+ logger.info("")
359
361
 
360
362
  # Fill in the details
361
363
  if logging:
362
- logger.info('Generating mixtures')
363
- progress = tqdm(total=num_mixtures, disable=not show_progress)
364
- mixtures = pp_tqdm_imap(_process_mixture, mixtures,
365
- progress=progress,
366
- initializer=_initializer,
367
- initargs=(location, save_mix, save_ft, save_segsnr, test))
364
+ logger.info("Generating mixtures")
365
+ progress = track(total=num_mixtures, disable=not show_progress)
366
+ mixtures = par_track(
367
+ partial(
368
+ _process_mixture,
369
+ location=location,
370
+ save_mix=save_mix,
371
+ save_ft=save_ft,
372
+ save_segsnr=save_segsnr,
373
+ test=test,
374
+ ),
375
+ mixtures,
376
+ progress=progress,
377
+ )
368
378
  progress.close()
369
379
 
370
380
  populate_mixture_table(location, mixtures, test)
@@ -378,77 +388,88 @@ def genmixdb(location: str,
378
388
  noise_samples_percent = (float(used_noise_samples) / float(noise_audio_samples)) * 100
379
389
 
380
390
  if logging:
381
- log_duration_and_sizes(total_duration=total_duration,
382
- num_classes=mixdb.num_classes,
383
- feature_step_samples=mixdb.feature_step_samples,
384
- feature_parameters=mixdb.feature_parameters,
385
- stride=mixdb.fg_stride,
386
- desc='Actual')
387
- logger.info('')
388
- logger.info(f'Used {noise_files_percent:,.0f}% of noise files')
389
- logger.info(f'Used {noise_samples_percent:,.0f}% of noise audio')
390
- logger.info('')
391
+ log_duration_and_sizes(
392
+ total_duration=total_duration,
393
+ num_classes=mixdb.num_classes,
394
+ feature_step_samples=mixdb.feature_step_samples,
395
+ feature_parameters=mixdb.feature_parameters,
396
+ stride=mixdb.fg_stride,
397
+ desc="Actual",
398
+ )
399
+ logger.info("")
400
+ logger.info(f"Used {noise_files_percent:,.0f}% of noise files")
401
+ logger.info(f"Used {noise_samples_percent:,.0f}% of noise audio")
402
+ logger.info("")
391
403
 
392
404
  if not test and save_json:
393
405
  if logging:
394
- logger.info(f'Writing JSON version of database to {location}')
406
+ logger.info(f"Writing JSON version of database to {location}")
395
407
  mixdb = MixtureDatabase(location)
396
408
  mixdb.save()
397
409
 
398
- return mixdb
399
-
400
410
 
401
- def _initializer(location: str, save_mix: bool, save_ft: bool, save_segsnr: bool, test: bool) -> None:
402
- MP_GLOBAL.mixdb = MixtureDatabase(location, test)
403
- MP_GLOBAL.save_mix = save_mix
404
- MP_GLOBAL.save_ft = save_ft
405
- MP_GLOBAL.save_segsnr = save_segsnr
406
-
407
-
408
- def _process_mixture(mixture: Mixture) -> Mixture:
411
+ def _process_mixture(
412
+ mixture: Mixture,
413
+ location: str,
414
+ save_mix: bool,
415
+ save_ft: bool,
416
+ save_segsnr: bool,
417
+ test: bool,
418
+ ) -> Mixture:
409
419
  from typing import Any
410
420
 
421
+ from sonusai.mixture import MixtureDatabase
411
422
  from sonusai.mixture import get_ft
412
423
  from sonusai.mixture import get_segsnr
413
- from sonusai.mixture import get_truth_t
424
+ from sonusai.mixture import get_truth
414
425
  from sonusai.mixture import update_mixture
415
- from sonusai.mixture import write_mixture_data
426
+ from sonusai.mixture import write_cached_data
416
427
  from sonusai.mixture import write_mixture_metadata
417
428
 
418
- with_data = MP_GLOBAL.save_mix or MP_GLOBAL.save_ft
419
- mixdb = MP_GLOBAL.mixdb
429
+ with_data = save_mix or save_ft
430
+ mixdb = MixtureDatabase(location, test)
420
431
 
421
432
  mixture, genmix_data = update_mixture(mixdb, mixture, with_data)
422
433
 
423
434
  if with_data:
424
435
  write_data: list[tuple[str, Any]] = []
425
436
 
426
- if MP_GLOBAL.save_mix:
427
- write_data.append(('targets', genmix_data.targets))
428
- write_data.append(('noise', genmix_data.noise))
429
- write_data.append(('mixture', genmix_data.mixture))
430
-
431
- if MP_GLOBAL.save_ft:
432
- truth_t = get_truth_t(mixdb=mixdb,
433
- mixture=mixture,
434
- targets_audio=genmix_data.targets,
435
- noise_audio=genmix_data.noise,
436
- mixture_audio=genmix_data.mixture)
437
- feature, truth_f = get_ft(mixdb=mixdb,
438
- mixture=mixture,
439
- mixture_audio=genmix_data.mixture,
440
- truth_t=truth_t)
441
- write_data.append(('feature', feature))
442
- write_data.append(('truth_f', truth_f))
443
-
444
- if MP_GLOBAL.save_segsnr:
445
- segsnr = get_segsnr(mixdb=mixdb,
446
- mixture=mixture,
447
- target_audio=genmix_data.target,
448
- noise=genmix_data.noise)
449
- write_data.append(('segsnr', segsnr))
450
-
451
- write_mixture_data(mixdb, mixture, write_data)
437
+ if save_mix:
438
+ write_data.append(("targets", genmix_data.targets))
439
+ write_data.append(("noise", genmix_data.noise))
440
+ write_data.append(("mixture", genmix_data.mixture))
441
+
442
+ if save_ft:
443
+ if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
444
+ raise RuntimeError("Mixture data was not generated properly")
445
+ truth_t = get_truth(
446
+ mixdb=mixdb,
447
+ mixture=mixture,
448
+ targets_audio=genmix_data.targets,
449
+ noise_audio=genmix_data.noise,
450
+ mixture_audio=genmix_data.mixture,
451
+ )
452
+ feature, truth_f = get_ft(
453
+ mixdb=mixdb,
454
+ mixture=mixture,
455
+ mixture_audio=genmix_data.mixture,
456
+ truth_t=truth_t,
457
+ )
458
+ write_data.append(("feature", feature))
459
+ write_data.append(("truth_f", truth_f))
460
+
461
+ if save_segsnr:
462
+ if genmix_data.target is None:
463
+ raise RuntimeError("Target data was not generated properly")
464
+ segsnr = get_segsnr(
465
+ mixdb=mixdb,
466
+ mixture=mixture,
467
+ target_audio=genmix_data.target,
468
+ noise=genmix_data.noise,
469
+ )
470
+ write_data.append(("segsnr", segsnr))
471
+
472
+ write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
452
473
  write_mixture_metadata(mixdb, mixture)
453
474
 
454
475
  return mixture
@@ -478,13 +499,13 @@ def main() -> None:
478
499
  from sonusai.mixture import load_config
479
500
  from sonusai.utils import seconds_to_hms
480
501
 
481
- verbose = args['--verbose']
482
- save_mix = args['--mix']
483
- save_ft = args['--ft']
484
- save_segsnr = args['--segsnr']
485
- dryrun = args['--dryrun']
486
- save_json = args['--json']
487
- location = args['LOC']
502
+ verbose = args["--verbose"]
503
+ save_mix = args["--mix"]
504
+ save_ft = args["--ft"]
505
+ save_segsnr = args["--segsnr"]
506
+ dryrun = args["--dryrun"]
507
+ save_json = args["--json"]
508
+ location = args["LOC"]
488
509
 
489
510
  start_time = time.monotonic()
490
511
 
@@ -493,30 +514,36 @@ def main() -> None:
493
514
 
494
515
  makedirs(location, exist_ok=True)
495
516
 
496
- create_file_handler(join(location, 'genmixdb.log'))
517
+ create_file_handler(join(location, "genmixdb.log"))
497
518
  update_console_handler(verbose)
498
- initial_log_messages('genmixdb')
519
+ initial_log_messages("genmixdb")
499
520
 
500
521
  if dryrun:
501
522
  config = load_config(location)
502
- logger.info('Dryrun configuration:')
523
+ logger.info("Dryrun configuration:")
503
524
  logger.info(yaml.dump(config))
504
525
  return
505
526
 
506
- logger.info(f'Creating mixture database for {location}')
507
- logger.info('')
508
-
509
- genmixdb(location=location,
510
- save_mix=save_mix,
511
- save_ft=save_ft,
512
- save_segsnr=save_segsnr,
513
- show_progress=True,
514
- save_json=save_json)
527
+ logger.info(f"Creating mixture database for {location}")
528
+ logger.info("")
529
+
530
+ try:
531
+ genmixdb(
532
+ location=location,
533
+ save_mix=save_mix,
534
+ save_ft=save_ft,
535
+ save_segsnr=save_segsnr,
536
+ show_progress=True,
537
+ save_json=save_json,
538
+ )
539
+ except Exception as e:
540
+ logger.debug(e)
541
+ raise
515
542
 
516
543
  end_time = time.monotonic()
517
- logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
518
- logger.info('')
544
+ logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
545
+ logger.info("")
519
546
 
520
547
 
521
- if __name__ == '__main__':
548
+ if __name__ == "__main__":
522
549
  main()