sonusai 0.19.6__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +54 -62
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +202 -235
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/truth.py +21 -34
  33. sonusai/mixture/truth_functions/__init__.py +6 -0
  34. sonusai/mixture/truth_functions/crm.py +51 -37
  35. sonusai/mixture/truth_functions/energy.py +95 -50
  36. sonusai/mixture/truth_functions/file.py +12 -8
  37. sonusai/mixture/truth_functions/metadata.py +24 -0
  38. sonusai/mixture/truth_functions/metrics.py +28 -0
  39. sonusai/mixture/truth_functions/phoneme.py +4 -5
  40. sonusai/mixture/truth_functions/sed.py +32 -23
  41. sonusai/mixture/truth_functions/target.py +62 -29
  42. sonusai/mkwav.py +20 -19
  43. sonusai/queries/queries.py +9 -15
  44. sonusai/speech/l2arctic.py +6 -2
  45. sonusai/summarize_metric_spenh.py +1 -1
  46. sonusai/utils/__init__.py +1 -0
  47. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  48. sonusai/utils/audio_devices.py +27 -18
  49. sonusai/utils/docstring.py +6 -3
  50. sonusai/utils/energy_f.py +5 -3
  51. sonusai/utils/human_readable_size.py +6 -6
  52. sonusai/utils/load_object.py +15 -0
  53. sonusai/utils/onnx_utils.py +2 -2
  54. sonusai/utils/print_mixture_details.py +3 -3
  55. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  56. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/RECORD +58 -56
  57. sonusai/mixture/truth_functions/datatypes.py +0 -37
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  59. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,15 @@
1
1
  # ruff: noqa: S608
2
- from .datatypes import AudiosT
3
2
  from .datatypes import AudioT
4
3
  from .datatypes import Augmentation
5
- from .datatypes import AugmentationRules
6
- from .datatypes import AugmentedTargets
4
+ from .datatypes import AugmentationRule
5
+ from .datatypes import AugmentedTarget
7
6
  from .datatypes import GenMixData
8
- from .datatypes import ImpulseResponseFiles
7
+ from .datatypes import ImpulseResponseFile
9
8
  from .datatypes import Mixture
10
- from .datatypes import Mixtures
11
- from .datatypes import NoiseFiles
12
- from .datatypes import SpectralMasks
13
- from .datatypes import TargetFiles
14
- from .datatypes import Targets
9
+ from .datatypes import NoiseFile
10
+ from .datatypes import SpectralMask
11
+ from .datatypes import Target
12
+ from .datatypes import TargetFile
15
13
  from .datatypes import UniversalSNRGenerator
16
14
  from .mixdb import MixtureDatabase
17
15
 
@@ -37,7 +35,7 @@ def initialize_db(location: str, test: bool = False) -> None:
37
35
  CREATE TABLE truth_parameters(
38
36
  id INTEGER PRIMARY KEY NOT NULL,
39
37
  name TEXT NOT NULL,
40
- parameters INTEGER NOT NULL)
38
+ parameters INTEGER)
41
39
  """)
42
40
 
43
41
  con.execute("""
@@ -121,8 +119,8 @@ def initialize_db(location: str, test: bool = False) -> None:
121
119
  id INTEGER PRIMARY KEY NOT NULL,
122
120
  file_id INTEGER NOT NULL,
123
121
  augmentation TEXT NOT NULL,
124
- gain FLOAT,
125
- FOREIGN KEY(file_id) REFERENCES target_file (id))
122
+ FOREIGN KEY(file_id) REFERENCES target_file (id),
123
+ UNIQUE(file_id, augmentation))
126
124
  """)
127
125
 
128
126
  con.execute("""
@@ -165,11 +163,12 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
165
163
  con = db_connection(location=location, readonly=False, test=test)
166
164
  con.execute(
167
165
  """
168
- INSERT INTO top (version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
166
+ INSERT INTO top (id, version, asr_configs, class_balancing, feature, noise_mix_mode, num_classes,
169
167
  seed, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
170
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
168
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
171
169
  """,
172
170
  (
171
+ 1,
173
172
  MIXDB_VERSION,
174
173
  json.dumps(config["asr_configs"]),
175
174
  config["class_balancing"],
@@ -271,7 +270,7 @@ def populate_truth_parameters_table(location: str, config: dict, test: bool = Fa
271
270
  con.close()
272
271
 
273
272
 
274
- def populate_target_file_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
273
+ def populate_target_file_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
275
274
  """Populate target file table"""
276
275
  import json
277
276
  from pathlib import Path
@@ -331,7 +330,7 @@ def populate_target_file_table(location: str, target_files: TargetFiles, test: b
331
330
 
332
331
  # Update textgrid_metadata_tiers in the top table
333
332
  con.execute(
334
- "UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
333
+ "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = top.id",
335
334
  (json.dumps(sorted(textgrid_metadata_tiers)), 1),
336
335
  )
337
336
 
@@ -339,7 +338,7 @@ def populate_target_file_table(location: str, target_files: TargetFiles, test: b
339
338
  con.close()
340
339
 
341
340
 
342
- def populate_noise_file_table(location: str, noise_files: NoiseFiles, test: bool = False) -> None:
341
+ def populate_noise_file_table(location: str, noise_files: list[NoiseFile], test: bool = False) -> None:
343
342
  """Populate noise file table"""
344
343
  from .mixdb import db_connection
345
344
 
@@ -353,7 +352,7 @@ def populate_noise_file_table(location: str, noise_files: NoiseFiles, test: bool
353
352
 
354
353
 
355
354
  def populate_impulse_response_file_table(
356
- location: str, impulse_response_files: ImpulseResponseFiles, test: bool = False
355
+ location: str, impulse_response_files: list[ImpulseResponseFile], test: bool = False
357
356
  ) -> None:
358
357
  """Populate impulse response file table"""
359
358
  import json
@@ -383,79 +382,73 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
383
382
 
384
383
  con = db_connection(location=location, readonly=False, test=test)
385
384
  con.execute(
386
- "UPDATE top SET mixid_width=? WHERE top.id = ?",
385
+ "UPDATE top SET mixid_width=? WHERE ? = top.id",
387
386
  (max_text_width(num_mixtures), 1),
388
387
  )
389
388
  con.commit()
390
389
  con.close()
391
390
 
392
391
 
393
- def populate_mixture_table(location: str, mixtures: Mixtures, test: bool = False) -> None:
394
- """Populate mixture table"""
395
- from .helpers import from_mixture
396
- from .helpers import from_target
397
- from .mixdb import db_connection
398
-
399
- con = db_connection(location=location, readonly=False, test=test)
400
-
401
- # Populate target table
402
- targets: list[tuple[int, str, float]] = []
403
- for mixture in mixtures:
404
- for target in mixture.targets:
405
- entry = from_target(target)
406
- if entry not in targets:
407
- targets.append(entry)
408
-
409
- con.executemany("INSERT INTO target (file_id, augmentation, gain) VALUES (?, ?, ?)", targets)
410
-
411
- # Populate mixture table
412
- cur = con.cursor()
413
- for mixture in mixtures:
414
- cur.execute(
415
- """
416
- INSERT INTO mixture (name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
417
- snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
418
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
419
- """,
420
- from_mixture(mixture),
421
- )
422
-
423
- mixture_id = cur.lastrowid
424
- for target in mixture.targets:
425
- target_id = con.execute(
426
- """
427
- SELECT target.id
428
- FROM target
429
- WHERE ? = target.file_id AND ? = target.augmentation AND ? = target.gain
430
- """,
431
- from_target(target),
432
- ).fetchone()[0]
433
- con.execute(
434
- "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
435
- (mixture_id, target_id),
436
- )
437
-
438
- con.commit()
439
- con.close()
392
+ def populate_mixture_table(
393
+ location: str,
394
+ noise_mix_mode: str,
395
+ augmented_targets: list[AugmentedTarget],
396
+ target_files: list[TargetFile],
397
+ target_augmentations: list[AugmentationRule],
398
+ noise_files: list[NoiseFile],
399
+ noise_augmentations: list[AugmentationRule],
400
+ spectral_masks: list[SpectralMask],
401
+ all_snrs: list[UniversalSNRGenerator],
402
+ mixups: list[int],
403
+ num_classes: int,
404
+ feature_step_samples: int,
405
+ num_ir: int,
406
+ test: bool = False,
407
+ ) -> tuple[int, int]:
408
+ """Generate mixtures and populate mixture table"""
409
+ if noise_mix_mode == "exhaustive":
410
+ func = _exhaustive_noise_mix
411
+ elif noise_mix_mode == "non-exhaustive":
412
+ func = _non_exhaustive_noise_mix
413
+ elif noise_mix_mode == "non-combinatorial":
414
+ func = _non_combinatorial_noise_mix
415
+ else:
416
+ raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
417
+
418
+ used_noise_files, used_noise_samples = func(
419
+ location=location,
420
+ augmented_targets=augmented_targets,
421
+ target_files=target_files,
422
+ target_augmentations=target_augmentations,
423
+ noise_files=noise_files,
424
+ noise_augmentations=noise_augmentations,
425
+ spectral_masks=spectral_masks,
426
+ all_snrs=all_snrs,
427
+ mixups=mixups,
428
+ num_classes=num_classes,
429
+ feature_step_samples=feature_step_samples,
430
+ num_ir=num_ir,
431
+ test=test,
432
+ )
440
433
 
434
+ return used_noise_files, used_noise_samples
441
435
 
442
- def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
443
- """Update mixture record with name and gains
444
436
 
445
- :param mixdb: Mixture database
446
- :param mixture: Mixture record
447
- :param with_data: Return audio data
448
- :return: Generated audio data (if requested)
449
- """
437
+ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData:
438
+ """Update mixture record with name and gains"""
450
439
  from .audio import get_next_noise
451
440
  from .augmentation import apply_gain
452
441
  from .datatypes import GenMixData
442
+ from .helpers import from_mixture
453
443
  from .helpers import get_target
444
+ from .mixdb import db_connection
454
445
 
446
+ mixdb = MixtureDatabase(location, test)
447
+ mixture = mixdb.mixture(m_id)
455
448
  mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
456
449
 
457
450
  noise_audio = _augmented_noise_audio(mixdb, mixture)
458
- noise_audio = get_next_noise(audio=noise_audio, offset=mixture.noise.offset, length=mixture.samples)
451
+ noise_audio = get_next_noise(audio=noise_audio, offset=mixture.noise_offset, length=mixture.samples)
459
452
 
460
453
  # Apply IR and sum targets audio before initializing the mixture SNR gains
461
454
  target_audio = get_target(mixdb, mixture, targets_audio)
@@ -466,8 +459,29 @@ def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = F
466
459
 
467
460
  mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
468
461
 
462
+ con = db_connection(location=location, readonly=False, test=test)
463
+ con.execute(
464
+ """
465
+ UPDATE mixture SET name=?,
466
+ noise_file_id=?,
467
+ noise_augmentation=?,
468
+ noise_offset=?,
469
+ noise_snr_gain=?,
470
+ random_snr=?,
471
+ snr=?,
472
+ samples=?,
473
+ spectral_mask_id=?,
474
+ spectral_mask_seed=?,
475
+ target_snr_gain=?
476
+ WHERE ? = mixture.id
477
+ """,
478
+ (*from_mixture(mixture), m_id + 1),
479
+ )
480
+ con.commit()
481
+ con.close()
482
+
469
483
  if not with_data:
470
- return mixture, GenMixData()
484
+ return GenMixData()
471
485
 
472
486
  # Apply SNR gains
473
487
  targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
@@ -477,7 +491,7 @@ def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = F
477
491
  target_audio = get_target(mixdb, mixture, targets_audio)
478
492
  mixture_audio = target_audio + noise_audio
479
493
 
480
- return mixture, GenMixData(
494
+ return GenMixData(
481
495
  mixture=mixture_audio,
482
496
  targets=targets_audio,
483
497
  target=target_audio,
@@ -502,7 +516,7 @@ def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
502
516
  return audio
503
517
 
504
518
 
505
- def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple[Mixture, AudiosT]:
519
+ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple[Mixture, list[AudioT]]:
506
520
  from .augmentation import apply_augmentation
507
521
  from .augmentation import pad_audio_to_length
508
522
 
@@ -517,13 +531,6 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
517
531
  )
518
532
  )
519
533
 
520
- # target_gain is used to back out the gain augmentation in order to return the target audio
521
- # to its normalized level when calculating truth (if needed).
522
- if target.augmentation.gain is not None:
523
- target.gain = round(10 ** (target.augmentation.gain / 20), ndigits=5)
524
- else:
525
- target.gain = 1
526
-
527
534
  mixture.samples = max([len(item) for item in targets_audio])
528
535
 
529
536
  for idx in range(len(targets_audio)):
@@ -540,14 +547,11 @@ def _initialize_mixture_gains(
540
547
  from sonusai.utils import asl_p56
541
548
  from sonusai.utils import db_to_linear
542
549
 
543
- if mixture.snr < -96:
550
+ if mixture.is_noise_only:
544
551
  # Special case for zeroing out target data
545
552
  mixture.target_snr_gain = 0
546
553
  mixture.noise_snr_gain = 1
547
- # Setting target_gain to zero will cause the truth to be all zeros.
548
- for target in mixture.targets:
549
- target.gain = 0
550
- elif mixture.snr > 96:
554
+ elif mixture.is_target_only:
551
555
  # Special case for zeroing out noise data
552
556
  mixture.target_snr_gain = 1
553
557
  mixture.noise_snr_gain = 0
@@ -598,97 +602,21 @@ def _initialize_mixture_gains(
598
602
  return mixture
599
603
 
600
604
 
601
- def generate_mixtures(
602
- noise_mix_mode: str,
603
- augmented_targets: AugmentedTargets,
604
- target_files: TargetFiles,
605
- target_augmentations: AugmentationRules,
606
- noise_files: NoiseFiles,
607
- noise_augmentations: AugmentationRules,
608
- spectral_masks: SpectralMasks,
609
- all_snrs: list[UniversalSNRGenerator],
610
- mixups: list[int],
611
- num_classes: int,
612
- feature_step_samples: int,
613
- num_ir: int,
614
- ) -> tuple[int, int, Mixtures]:
615
- """Generate mixtures
616
-
617
- :param noise_mix_mode: Noise mix mode
618
- :param augmented_targets: List of augmented targets
619
- :param target_files: List of target files
620
- :param target_augmentations: List of target augmentations
621
- :param noise_files: List of noise files
622
- :param noise_augmentations: List of noise augmentations
623
- :param spectral_masks: List of spectral masks
624
- :param all_snrs: List of all SNRs
625
- :param mixups: List of mixup values
626
- :param num_classes: Number of classes
627
- :param feature_step_samples: Number of samples in a feature step
628
- :param num_ir: Number of impulse response files
629
- :return: (Number of noise files used, number of noise samples used, list of mixture records)
630
- """
631
- if noise_mix_mode == "exhaustive":
632
- return _exhaustive_noise_mix(
633
- augmented_targets=augmented_targets,
634
- target_files=target_files,
635
- target_augmentations=target_augmentations,
636
- noise_files=noise_files,
637
- noise_augmentations=noise_augmentations,
638
- spectral_masks=spectral_masks,
639
- all_snrs=all_snrs,
640
- mixups=mixups,
641
- num_classes=num_classes,
642
- feature_step_samples=feature_step_samples,
643
- num_ir=num_ir,
644
- )
645
-
646
- if noise_mix_mode == "non-exhaustive":
647
- return _non_exhaustive_noise_mix(
648
- augmented_targets=augmented_targets,
649
- target_files=target_files,
650
- target_augmentations=target_augmentations,
651
- noise_files=noise_files,
652
- noise_augmentations=noise_augmentations,
653
- spectral_masks=spectral_masks,
654
- all_snrs=all_snrs,
655
- mixups=mixups,
656
- num_classes=num_classes,
657
- feature_step_samples=feature_step_samples,
658
- num_ir=num_ir,
659
- )
660
-
661
- if noise_mix_mode == "non-combinatorial":
662
- return _non_combinatorial_noise_mix(
663
- augmented_targets=augmented_targets,
664
- target_files=target_files,
665
- target_augmentations=target_augmentations,
666
- noise_files=noise_files,
667
- noise_augmentations=noise_augmentations,
668
- spectral_masks=spectral_masks,
669
- all_snrs=all_snrs,
670
- mixups=mixups,
671
- num_classes=num_classes,
672
- feature_step_samples=feature_step_samples,
673
- num_ir=num_ir,
674
- )
675
-
676
- raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
677
-
678
-
679
605
  def _exhaustive_noise_mix(
680
- augmented_targets: AugmentedTargets,
681
- target_files: TargetFiles,
682
- target_augmentations: AugmentationRules,
683
- noise_files: NoiseFiles,
684
- noise_augmentations: AugmentationRules,
685
- spectral_masks: SpectralMasks,
606
+ location: str,
607
+ augmented_targets: list[AugmentedTarget],
608
+ target_files: list[TargetFile],
609
+ target_augmentations: list[AugmentationRule],
610
+ noise_files: list[NoiseFile],
611
+ noise_augmentations: list[AugmentationRule],
612
+ spectral_masks: list[SpectralMask],
686
613
  all_snrs: list[UniversalSNRGenerator],
687
614
  mixups: list[int],
688
615
  num_classes: int,
689
616
  feature_step_samples: int,
690
617
  num_ir: int,
691
- ) -> tuple[int, int, Mixtures]:
618
+ test: bool = False,
619
+ ) -> tuple[int, int]:
692
620
  """Use every noise/augmentation with every target/augmentation"""
693
621
  from random import randint
694
622
 
@@ -697,12 +625,10 @@ def _exhaustive_noise_mix(
697
625
  from .augmentation import augmentation_from_rule
698
626
  from .augmentation import estimate_augmented_length_from_length
699
627
  from .datatypes import Mixture
700
- from .datatypes import Mixtures
701
628
  from .datatypes import Noise
702
629
  from .datatypes import UniversalSNR
703
630
  from .targets import get_augmented_target_ids_for_mixup
704
631
 
705
- mixtures: Mixtures = []
706
632
  m_id = 0
707
633
  used_noise_files = len(noise_files) * len(noise_augmentations)
708
634
  used_noise_samples = 0
@@ -739,42 +665,44 @@ def _exhaustive_noise_mix(
739
665
 
740
666
  for spectral_mask_id in range(len(spectral_masks)):
741
667
  for snr in all_snrs:
742
- mixtures.append(
743
- Mixture(
668
+ _insert_mixture_record(
669
+ location=location,
670
+ m_id=m_id,
671
+ mixture=Mixture(
744
672
  targets=targets,
745
673
  name=str(m_id),
746
- noise=Noise(
747
- file_id=noise_file_id + 1,
748
- augmentation=noise_augmentation,
749
- offset=noise_offset,
750
- ),
674
+ noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
675
+ noise_offset=noise_offset,
751
676
  samples=target_length,
752
677
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
753
678
  spectral_mask_id=spectral_mask_id + 1,
754
679
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
755
- )
680
+ ),
681
+ test=test,
756
682
  )
757
683
  m_id += 1
758
684
 
759
685
  noise_offset = int((noise_offset + target_length) % noise_length)
760
686
  used_noise_samples += target_length
761
687
 
762
- return used_noise_files, used_noise_samples, mixtures
688
+ return used_noise_files, used_noise_samples
763
689
 
764
690
 
765
691
  def _non_exhaustive_noise_mix(
766
- augmented_targets: AugmentedTargets,
767
- target_files: TargetFiles,
768
- target_augmentations: AugmentationRules,
769
- noise_files: NoiseFiles,
770
- noise_augmentations: AugmentationRules,
771
- spectral_masks: SpectralMasks,
692
+ location: str,
693
+ augmented_targets: list[AugmentedTarget],
694
+ target_files: list[TargetFile],
695
+ target_augmentations: list[AugmentationRule],
696
+ noise_files: list[NoiseFile],
697
+ noise_augmentations: list[AugmentationRule],
698
+ spectral_masks: list[SpectralMask],
772
699
  all_snrs: list[UniversalSNRGenerator],
773
700
  mixups: list[int],
774
701
  num_classes: int,
775
702
  feature_step_samples: int,
776
703
  num_ir: int,
777
- ) -> tuple[int, int, Mixtures]:
704
+ test: bool = False,
705
+ ) -> tuple[int, int]:
778
706
  """Cycle through every target/augmentation without necessarily using all noise/augmentation combinations
779
707
  (reduced data set).
780
708
  """
@@ -783,12 +711,10 @@ def _non_exhaustive_noise_mix(
783
711
  import numpy as np
784
712
 
785
713
  from .datatypes import Mixture
786
- from .datatypes import Mixtures
787
714
  from .datatypes import Noise
788
715
  from .datatypes import UniversalSNR
789
716
  from .targets import get_augmented_target_ids_for_mixup
790
717
 
791
- mixtures: Mixtures = []
792
718
  m_id = 0
793
719
  used_noise_files = set()
794
720
  used_noise_samples = 0
@@ -837,39 +763,41 @@ def _non_exhaustive_noise_mix(
837
763
 
838
764
  used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
839
765
 
840
- mixtures.append(
841
- Mixture(
766
+ _insert_mixture_record(
767
+ location=location,
768
+ m_id=m_id,
769
+ mixture=Mixture(
842
770
  targets=targets,
843
771
  name=str(m_id),
844
- noise=Noise(
845
- file_id=noise_file_id + 1,
846
- augmentation=noise_augmentation,
847
- offset=noise_offset,
848
- ),
772
+ noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
773
+ noise_offset=noise_offset,
849
774
  samples=target_length,
850
775
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
851
776
  spectral_mask_id=spectral_mask_id + 1,
852
777
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
853
- )
778
+ ),
779
+ test=test,
854
780
  )
855
781
  m_id += 1
856
782
 
857
- return len(used_noise_files), used_noise_samples, mixtures
783
+ return len(used_noise_files), used_noise_samples
858
784
 
859
785
 
860
786
  def _non_combinatorial_noise_mix(
861
- augmented_targets: AugmentedTargets,
862
- target_files: TargetFiles,
863
- target_augmentations: AugmentationRules,
864
- noise_files: NoiseFiles,
865
- noise_augmentations: AugmentationRules,
866
- spectral_masks: SpectralMasks,
787
+ location: str,
788
+ augmented_targets: list[AugmentedTarget],
789
+ target_files: list[TargetFile],
790
+ target_augmentations: list[AugmentationRule],
791
+ noise_files: list[NoiseFile],
792
+ noise_augmentations: list[AugmentationRule],
793
+ spectral_masks: list[SpectralMask],
867
794
  all_snrs: list[UniversalSNRGenerator],
868
795
  mixups: list[int],
869
796
  num_classes: int,
870
797
  feature_step_samples: int,
871
798
  num_ir: int,
872
- ) -> tuple[int, int, Mixtures]:
799
+ test: bool = False,
800
+ ) -> tuple[int, int]:
873
801
  """Combine a target/augmentation with a single cut of a noise/augmentation non-exhaustively
874
802
  (each target/augmentation does not use each noise/augmentation). Cut has random start and loop back to
875
803
  beginning if end of noise/augmentation is reached.
@@ -880,12 +808,10 @@ def _non_combinatorial_noise_mix(
880
808
  import numpy as np
881
809
 
882
810
  from .datatypes import Mixture
883
- from .datatypes import Mixtures
884
811
  from .datatypes import Noise
885
812
  from .datatypes import UniversalSNR
886
813
  from .targets import get_augmented_target_ids_for_mixup
887
814
 
888
- mixtures: Mixtures = []
889
815
  m_id = 0
890
816
  used_noise_files = set()
891
817
  used_noise_samples = 0
@@ -931,31 +857,31 @@ def _non_combinatorial_noise_mix(
931
857
 
932
858
  used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
933
859
 
934
- mixtures.append(
935
- Mixture(
860
+ _insert_mixture_record(
861
+ location=location,
862
+ m_id=m_id,
863
+ mixture=Mixture(
936
864
  targets=targets,
937
865
  name=str(m_id),
938
- noise=Noise(
939
- file_id=noise_file_id + 1,
940
- augmentation=noise_augmentation,
941
- offset=choice(range(noise_length)), # noqa: S311
942
- ),
866
+ noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
867
+ noise_offset=choice(range(noise_length)), # noqa: S311
943
868
  samples=target_length,
944
869
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
945
870
  spectral_mask_id=spectral_mask_id + 1,
946
871
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
947
- )
872
+ ),
873
+ test=test,
948
874
  )
949
875
  m_id += 1
950
876
 
951
- return len(used_noise_files), used_noise_samples, mixtures
877
+ return len(used_noise_files), used_noise_samples
952
878
 
953
879
 
954
880
  def _get_next_noise_indices(
955
881
  noise_file_id: int | None,
956
882
  noise_augmentation_id: int | None,
957
- noise_files: NoiseFiles,
958
- noise_augmentations: AugmentationRules,
883
+ noise_files: list[NoiseFile],
884
+ noise_augmentations: list[AugmentationRule],
959
885
  num_ir: int,
960
886
  ) -> tuple[int, int, Augmentation, int]:
961
887
  from .augmentation import augmentation_from_rule
@@ -984,8 +910,8 @@ def _get_next_noise_offset(
984
910
  noise_augmentation_id: int | None,
985
911
  noise_offset: int | None,
986
912
  target_length: int,
987
- noise_files: NoiseFiles,
988
- noise_augmentations: AugmentationRules,
913
+ noise_files: list[NoiseFile],
914
+ noise_augmentations: list[AugmentationRule],
989
915
  num_ir: int,
990
916
  ) -> tuple[int, int, Augmentation, int]:
991
917
  from .augmentation import augmentation_from_rule
@@ -1018,18 +944,16 @@ def _get_next_noise_offset(
1018
944
 
1019
945
  def _get_target_info(
1020
946
  augmented_target_ids: list[int],
1021
- augmented_targets: AugmentedTargets,
1022
- target_files: TargetFiles,
1023
- target_augmentations: AugmentationRules,
947
+ augmented_targets: list[AugmentedTarget],
948
+ target_files: list[TargetFile],
949
+ target_augmentations: list[AugmentationRule],
1024
950
  feature_step_samples: int,
1025
951
  num_ir: int,
1026
- ) -> tuple[Targets, int]:
952
+ ) -> tuple[list[Target], int]:
1027
953
  from .augmentation import augmentation_from_rule
1028
954
  from .augmentation import estimate_augmented_length_from_length
1029
- from .datatypes import Target
1030
- from .datatypes import Targets
1031
955
 
1032
- mixups: Targets = []
956
+ mixups: list[Target] = []
1033
957
  target_length = 0
1034
958
  for idx in augmented_target_ids:
1035
959
  tfi = augmented_targets[idx].target_id
@@ -1049,6 +973,49 @@ def _get_target_info(
1049
973
  return mixups, target_length
1050
974
 
1051
975
 
976
+ def _insert_mixture_record(location: str, m_id: int, mixture: Mixture, test: bool = False) -> None:
977
+ from .helpers import from_mixture
978
+ from .helpers import from_target
979
+ from .mixdb import db_connection
980
+
981
+ con = db_connection(location=location, readonly=False, test=test)
982
+ # Populate target table
983
+ for target in mixture.targets:
984
+ con.execute(
985
+ """
986
+ INSERT OR IGNORE INTO target (file_id, augmentation)
987
+ VALUES (?, ?)
988
+ """,
989
+ from_target(target),
990
+ )
991
+
992
+ # Populate mixture table
993
+ con.execute(
994
+ """
995
+ INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
996
+ snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
997
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
998
+ """,
999
+ (m_id + 1, *from_mixture(mixture)),
1000
+ )
1001
+
1002
+ for target in mixture.targets:
1003
+ target_id = con.execute(
1004
+ """
1005
+ SELECT target.id
1006
+ FROM target
1007
+ WHERE ? = target.file_id AND ? = target.augmentation
1008
+ """,
1009
+ from_target(target),
1010
+ ).fetchone()[0]
1011
+ con.execute(
1012
+ "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
1013
+ (m_id + 1, target_id),
1014
+ )
1015
+ con.commit()
1016
+ con.close()
1017
+
1018
+
1052
1019
  def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
1053
1020
  from .datatypes import UniversalSNRGenerator
1054
1021
 
@@ -1073,7 +1040,7 @@ def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
1073
1040
  return sorted(tg.tierNames)
1074
1041
 
1075
1042
 
1076
- def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
1043
+ def _populate_speaker_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
1077
1044
  """Populate speaker table"""
1078
1045
  import json
1079
1046
  from pathlib import Path
@@ -1122,7 +1089,7 @@ def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool
1122
1089
  if description[0] not in ("id", "parent")
1123
1090
  ]
1124
1091
  con.execute(
1125
- "UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?",
1092
+ "UPDATE top SET speaker_metadata_tiers=? WHERE ? = top.id",
1126
1093
  (json.dumps(tiers), 1),
1127
1094
  )
1128
1095
 
@@ -1133,7 +1100,7 @@ def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool
1133
1100
  con.close()
1134
1101
 
1135
1102
 
1136
- def _populate_truth_config_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
1103
+ def _populate_truth_config_table(location: str, target_files: list[TargetFile], test: bool = False) -> None:
1137
1104
  """Populate truth_config table"""
1138
1105
  import json
1139
1106