sonusai 0.19.8__py3-none-any.whl → 0.19.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/audio.py CHANGED
@@ -44,44 +44,77 @@ def validate_input_file(input_filepath: str | Path) -> None:
44
44
  raise OSError(f"This installation cannot process .{ext} files")
45
45
 
46
46
 
47
- @lru_cache
48
- def get_sample_rate(name: str | Path) -> int:
47
+ def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
49
48
  """Get sample rate from audio file
50
49
 
51
50
  :param name: File name
51
+ :param use_cache: If true, use LRU caching
52
52
  :return: Sample rate
53
53
  """
54
+ if use_cache:
55
+ return _get_sample_rate(name)
56
+ return _get_sample_rate.__wrapped__(name)
57
+
58
+
59
+ @lru_cache
60
+ def _get_sample_rate(name: str | Path) -> int:
54
61
  from .soundfile_audio import get_sample_rate
55
62
 
56
63
  return get_sample_rate(name)
57
64
 
58
65
 
59
- @lru_cache
60
- def read_audio(name: str | Path) -> AudioT:
66
+ def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
61
67
  """Read audio data from a file
62
68
 
63
69
  :param name: File name
70
+ :param use_cache: If true, use LRU caching
64
71
  :return: Array of time domain audio data
65
72
  """
73
+ if use_cache:
74
+ return _read_audio(name)
75
+ return _read_audio.__wrapped__(name)
76
+
77
+
78
+ @lru_cache
79
+ def _read_audio(name: str | Path) -> AudioT:
66
80
  from .soundfile_audio import read_audio
67
81
 
68
82
  return read_audio(name)
69
83
 
70
84
 
71
- @lru_cache
72
- def read_ir(name: str | Path) -> ImpulseResponseData:
85
+ def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
73
86
  """Read impulse response data
74
87
 
75
88
  :param name: File name
89
+ :param use_cache: If true, use LRU caching
76
90
  :return: ImpulseResponseData object
77
91
  """
92
+ if use_cache:
93
+ return _read_ir(name)
94
+ return _read_ir.__wrapped__(name)
95
+
96
+
97
+ @lru_cache
98
+ def _read_ir(name: str | Path) -> ImpulseResponseData:
78
99
  from .soundfile_audio import read_ir
79
100
 
80
101
  return read_ir(name)
81
102
 
82
103
 
104
+ def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
105
+ """Get the number of samples resampled to the SonusAI sample rate in the given file
106
+
107
+ :param name: File name
108
+ :param use_cache: If true, use LRU caching
109
+ :return: number of samples in resampled audio
110
+ """
111
+ if use_cache:
112
+ return _get_num_samples(name)
113
+ return _get_num_samples.__wrapped__(name)
114
+
115
+
83
116
  @lru_cache
84
- def get_num_samples(name: str | Path) -> int:
117
+ def _get_num_samples(name: str | Path) -> int:
85
118
  """Get the number of samples resampled to the SonusAI sample rate in the given file
86
119
 
87
120
  :param name: File name
@@ -119,8 +119,7 @@ def initialize_db(location: str, test: bool = False) -> None:
119
119
  id INTEGER PRIMARY KEY NOT NULL,
120
120
  file_id INTEGER NOT NULL,
121
121
  augmentation TEXT NOT NULL,
122
- FOREIGN KEY(file_id) REFERENCES target_file (id),
123
- UNIQUE(file_id, augmentation))
122
+ FOREIGN KEY(file_id) REFERENCES target_file (id))
124
123
  """)
125
124
 
126
125
  con.execute("""
@@ -389,8 +388,7 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
389
388
  con.close()
390
389
 
391
390
 
392
- def populate_mixture_table(
393
- location: str,
391
+ def generate_mixtures(
394
392
  noise_mix_mode: str,
395
393
  augmented_targets: list[AugmentedTarget],
396
394
  target_files: list[TargetFile],
@@ -403,9 +401,8 @@ def populate_mixture_table(
403
401
  num_classes: int,
404
402
  feature_step_samples: int,
405
403
  num_ir: int,
406
- test: bool = False,
407
- ) -> tuple[int, int]:
408
- """Generate mixtures and populate mixture table"""
404
+ ) -> tuple[int, int, list[Mixture]]:
405
+ """Generate mixtures"""
409
406
  if noise_mix_mode == "exhaustive":
410
407
  func = _exhaustive_noise_mix
411
408
  elif noise_mix_mode == "non-exhaustive":
@@ -415,8 +412,7 @@ def populate_mixture_table(
415
412
  else:
416
413
  raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
417
414
 
418
- used_noise_files, used_noise_samples = func(
419
- location=location,
415
+ return func(
420
416
  augmented_targets=augmented_targets,
421
417
  target_files=target_files,
422
418
  target_augmentations=target_augmentations,
@@ -428,23 +424,76 @@ def populate_mixture_table(
428
424
  num_classes=num_classes,
429
425
  feature_step_samples=feature_step_samples,
430
426
  num_ir=num_ir,
431
- test=test,
432
427
  )
433
428
 
434
- return used_noise_files, used_noise_samples
429
+
430
+ def populate_mixture_table(
431
+ location: str,
432
+ mixtures: list[Mixture],
433
+ test: bool = False,
434
+ logging: bool = False,
435
+ show_progress: bool = False,
436
+ ) -> None:
437
+ """Populate mixture table"""
438
+ from sonusai import logger
439
+ from sonusai.utils import track
440
+
441
+ from .helpers import from_mixture
442
+ from .helpers import from_target
443
+ from .mixdb import db_connection
444
+
445
+ con = db_connection(location=location, readonly=False, test=test)
446
+
447
+ # Populate target table
448
+ if logging:
449
+ logger.info("Populating target table")
450
+ targets: list[tuple[int, str]] = []
451
+ for mixture in mixtures:
452
+ for target in mixture.targets:
453
+ entry = from_target(target)
454
+ if entry not in targets:
455
+ targets.append(entry)
456
+ for target in track(targets, disable=not show_progress):
457
+ con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
458
+
459
+ # Populate mixture table
460
+ if logging:
461
+ logger.info("Populating mixture table")
462
+ for mixture in track(mixtures, disable=not show_progress):
463
+ m_id = int(mixture.name)
464
+ con.execute(
465
+ """
466
+ INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
467
+ snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
468
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
469
+ """,
470
+ (m_id + 1, *from_mixture(mixture)),
471
+ )
472
+
473
+ for target in mixture.targets:
474
+ target_id = con.execute(
475
+ """
476
+ SELECT target.id
477
+ FROM target
478
+ WHERE ? = target.file_id AND ? = target.augmentation
479
+ """,
480
+ from_target(target),
481
+ ).fetchone()[0]
482
+ con.execute(
483
+ "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
484
+ (m_id + 1, target_id),
485
+ )
486
+
487
+ con.commit()
488
+ con.close()
435
489
 
436
490
 
437
- def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData:
491
+ def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
438
492
  """Update mixture record with name and gains"""
439
493
  from .audio import get_next_noise
440
494
  from .augmentation import apply_gain
441
- from .datatypes import GenMixData
442
- from .helpers import from_mixture
443
495
  from .helpers import get_target
444
- from .mixdb import db_connection
445
496
 
446
- mixdb = MixtureDatabase(location, test)
447
- mixture = mixdb.mixture(m_id)
448
497
  mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
449
498
 
450
499
  noise_audio = _augmented_noise_audio(mixdb, mixture)
@@ -459,29 +508,8 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
459
508
 
460
509
  mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
461
510
 
462
- con = db_connection(location=location, readonly=False, test=test)
463
- con.execute(
464
- """
465
- UPDATE mixture SET name=?,
466
- noise_file_id=?,
467
- noise_augmentation=?,
468
- noise_offset=?,
469
- noise_snr_gain=?,
470
- random_snr=?,
471
- snr=?,
472
- samples=?,
473
- spectral_mask_id=?,
474
- spectral_mask_seed=?,
475
- target_snr_gain=?
476
- WHERE ? = mixture.id
477
- """,
478
- (*from_mixture(mixture), m_id + 1),
479
- )
480
- con.commit()
481
- con.close()
482
-
483
511
  if not with_data:
484
- return GenMixData()
512
+ return mixture, GenMixData()
485
513
 
486
514
  # Apply SNR gains
487
515
  targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
@@ -491,7 +519,7 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
491
519
  target_audio = get_target(mixdb, mixture, targets_audio)
492
520
  mixture_audio = target_audio + noise_audio
493
521
 
494
- return GenMixData(
522
+ return mixture, GenMixData(
495
523
  mixture=mixture_audio,
496
524
  targets=targets_audio,
497
525
  target=target_audio,
@@ -511,7 +539,7 @@ def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
511
539
  audio = read_audio(noise.name)
512
540
  audio = apply_augmentation(audio, noise_augmentation)
513
541
  if noise_augmentation.ir is not None:
514
- audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir)))
542
+ audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir))) # pyright: ignore [reportArgumentType]
515
543
 
516
544
  return audio
517
545
 
@@ -540,7 +568,10 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
540
568
 
541
569
 
542
570
  def _initialize_mixture_gains(
543
- mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT
571
+ mixdb: MixtureDatabase,
572
+ mixture: Mixture,
573
+ target_audio: AudioT,
574
+ noise_audio: AudioT,
544
575
  ) -> Mixture:
545
576
  import numpy as np
546
577
 
@@ -603,7 +634,6 @@ def _initialize_mixture_gains(
603
634
 
604
635
 
605
636
  def _exhaustive_noise_mix(
606
- location: str,
607
637
  augmented_targets: list[AugmentedTarget],
608
638
  target_files: list[TargetFile],
609
639
  target_augmentations: list[AugmentationRule],
@@ -615,9 +645,8 @@ def _exhaustive_noise_mix(
615
645
  num_classes: int,
616
646
  feature_step_samples: int,
617
647
  num_ir: int,
618
- test: bool = False,
619
- ) -> tuple[int, int]:
620
- """Use every noise/augmentation with every target/augmentation"""
648
+ ) -> tuple[int, int, list[Mixture]]:
649
+ """Use every noise/augmentation with every target/augmentation+interferences/augmentation"""
621
650
  from random import randint
622
651
 
623
652
  import numpy as np
@@ -643,6 +672,8 @@ def _exhaustive_noise_mix(
643
672
  )
644
673
  for mixup in mixups
645
674
  ]
675
+
676
+ mixtures: list[Mixture] = []
646
677
  for noise_file_id in range(len(noise_files)):
647
678
  for noise_augmentation_rule in noise_augmentations:
648
679
  noise_augmentation = augmentation_from_rule(noise_augmentation_rule, num_ir)
@@ -665,10 +696,8 @@ def _exhaustive_noise_mix(
665
696
 
666
697
  for spectral_mask_id in range(len(spectral_masks)):
667
698
  for snr in all_snrs:
668
- _insert_mixture_record(
669
- location=location,
670
- m_id=m_id,
671
- mixture=Mixture(
699
+ mixtures.append(
700
+ Mixture(
672
701
  targets=targets,
673
702
  name=str(m_id),
674
703
  noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
@@ -677,19 +706,17 @@ def _exhaustive_noise_mix(
677
706
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
678
707
  spectral_mask_id=spectral_mask_id + 1,
679
708
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
680
- ),
681
- test=test,
709
+ )
682
710
  )
683
711
  m_id += 1
684
712
 
685
713
  noise_offset = int((noise_offset + target_length) % noise_length)
686
714
  used_noise_samples += target_length
687
715
 
688
- return used_noise_files, used_noise_samples
716
+ return used_noise_files, used_noise_samples, mixtures
689
717
 
690
718
 
691
719
  def _non_exhaustive_noise_mix(
692
- location: str,
693
720
  augmented_targets: list[AugmentedTarget],
694
721
  target_files: list[TargetFile],
695
722
  target_augmentations: list[AugmentationRule],
@@ -701,10 +728,9 @@ def _non_exhaustive_noise_mix(
701
728
  num_classes: int,
702
729
  feature_step_samples: int,
703
730
  num_ir: int,
704
- test: bool = False,
705
- ) -> tuple[int, int]:
706
- """Cycle through every target/augmentation without necessarily using all noise/augmentation combinations
707
- (reduced data set).
731
+ ) -> tuple[int, int, list[Mixture]]:
732
+ """Cycle through every target/augmentation+interferences/augmentation without necessarily using all
733
+ noise/augmentation combinations (reduced data set).
708
734
  """
709
735
  from random import randint
710
736
 
@@ -732,6 +758,8 @@ def _non_exhaustive_noise_mix(
732
758
  )
733
759
  for mixup in mixups
734
760
  ]
761
+
762
+ mixtures: list[Mixture] = []
735
763
  for mixup in augmented_target_indices_for_mixups:
736
764
  for augmented_target_indices in mixup:
737
765
  targets, target_length = _get_target_info(
@@ -763,10 +791,8 @@ def _non_exhaustive_noise_mix(
763
791
 
764
792
  used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
765
793
 
766
- _insert_mixture_record(
767
- location=location,
768
- m_id=m_id,
769
- mixture=Mixture(
794
+ mixtures.append(
795
+ Mixture(
770
796
  targets=targets,
771
797
  name=str(m_id),
772
798
  noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
@@ -775,16 +801,14 @@ def _non_exhaustive_noise_mix(
775
801
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
776
802
  spectral_mask_id=spectral_mask_id + 1,
777
803
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
778
- ),
779
- test=test,
804
+ )
780
805
  )
781
806
  m_id += 1
782
807
 
783
- return len(used_noise_files), used_noise_samples
808
+ return len(used_noise_files), used_noise_samples, mixtures
784
809
 
785
810
 
786
811
  def _non_combinatorial_noise_mix(
787
- location: str,
788
812
  augmented_targets: list[AugmentedTarget],
789
813
  target_files: list[TargetFile],
790
814
  target_augmentations: list[AugmentationRule],
@@ -796,11 +820,10 @@ def _non_combinatorial_noise_mix(
796
820
  num_classes: int,
797
821
  feature_step_samples: int,
798
822
  num_ir: int,
799
- test: bool = False,
800
- ) -> tuple[int, int]:
801
- """Combine a target/augmentation with a single cut of a noise/augmentation non-exhaustively
802
- (each target/augmentation does not use each noise/augmentation). Cut has random start and loop back to
803
- beginning if end of noise/augmentation is reached.
823
+ ) -> tuple[int, int, list[Mixture]]:
824
+ """Combine a target/augmentation+interferences/augmentation with a single cut of a noise/augmentation
825
+ non-exhaustively (each target/augmentation+interferences/augmentation does not use each noise/augmentation).
826
+ Cut has random start and loop back to beginning if end of noise/augmentation is reached.
804
827
  """
805
828
  from random import choice
806
829
  from random import randint
@@ -828,6 +851,8 @@ def _non_combinatorial_noise_mix(
828
851
  )
829
852
  for mixup in mixups
830
853
  ]
854
+
855
+ mixtures: list[Mixture] = []
831
856
  for mixup in augmented_target_indices_for_mixups:
832
857
  for augmented_target_indices in mixup:
833
858
  targets, target_length = _get_target_info(
@@ -857,10 +882,8 @@ def _non_combinatorial_noise_mix(
857
882
 
858
883
  used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
859
884
 
860
- _insert_mixture_record(
861
- location=location,
862
- m_id=m_id,
863
- mixture=Mixture(
885
+ mixtures.append(
886
+ Mixture(
864
887
  targets=targets,
865
888
  name=str(m_id),
866
889
  noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
@@ -869,12 +892,11 @@ def _non_combinatorial_noise_mix(
869
892
  snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
870
893
  spectral_mask_id=spectral_mask_id + 1,
871
894
  spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
872
- ),
873
- test=test,
895
+ )
874
896
  )
875
897
  m_id += 1
876
898
 
877
- return len(used_noise_files), used_noise_samples
899
+ return len(used_noise_files), used_noise_samples, mixtures
878
900
 
879
901
 
880
902
  def _get_next_noise_indices(
@@ -973,49 +995,6 @@ def _get_target_info(
973
995
  return mixups, target_length
974
996
 
975
997
 
976
- def _insert_mixture_record(location: str, m_id: int, mixture: Mixture, test: bool = False) -> None:
977
- from .helpers import from_mixture
978
- from .helpers import from_target
979
- from .mixdb import db_connection
980
-
981
- con = db_connection(location=location, readonly=False, test=test)
982
- # Populate target table
983
- for target in mixture.targets:
984
- con.execute(
985
- """
986
- INSERT OR IGNORE INTO target (file_id, augmentation)
987
- VALUES (?, ?)
988
- """,
989
- from_target(target),
990
- )
991
-
992
- # Populate mixture table
993
- con.execute(
994
- """
995
- INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
996
- snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
997
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
998
- """,
999
- (m_id + 1, *from_mixture(mixture)),
1000
- )
1001
-
1002
- for target in mixture.targets:
1003
- target_id = con.execute(
1004
- """
1005
- SELECT target.id
1006
- FROM target
1007
- WHERE ? = target.file_id AND ? = target.augmentation
1008
- """,
1009
- from_target(target),
1010
- ).fetchone()[0]
1011
- con.execute(
1012
- "INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
1013
- (m_id + 1, target_id),
1014
- )
1015
- con.commit()
1016
- con.close()
1017
-
1018
-
1019
998
  def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
1020
999
  from .datatypes import UniversalSNRGenerator
1021
1000
 
@@ -135,14 +135,20 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
135
135
  return results
136
136
 
137
137
 
138
- def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
138
+ def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> str:
139
139
  """Create a string of metadata for a Mixture
140
140
 
141
141
  :param mixdb: Mixture database
142
142
  :param m_id: Mixture ID
143
+ :param mixture: Mixture record
143
144
  :return: String of metadata
144
145
  """
145
- mixture = mixdb.mixture(m_id)
146
+ if m_id is not None:
147
+ mixture = mixdb.mixture(m_id)
148
+
149
+ if mixture is None:
150
+ raise ValueError("No mixture specified.")
151
+
146
152
  metadata = ""
147
153
  speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
148
154
  for mi, target in enumerate(mixture.targets):
@@ -173,17 +179,25 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
173
179
  return metadata
174
180
 
175
181
 
176
- def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> None:
182
+ def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> None:
177
183
  """Write mixture metadata to a text file
178
184
 
179
185
  :param mixdb: Mixture database
180
186
  :param m_id: Mixture ID
187
+ :param mixture: Mixture record
181
188
  """
182
189
  from os.path import join
183
190
 
184
- name = join(mixdb.location, "mixture", mixdb.mixture(m_id).name, "metadata.txt")
191
+ if m_id is not None:
192
+ name = mixdb.mixture(m_id).name
193
+ elif mixture is not None:
194
+ name = mixture.name
195
+ else:
196
+ raise ValueError("No mixture specified.")
197
+
198
+ name = join(mixdb.location, "mixture", name, "metadata.txt")
185
199
  with open(file=name, mode="w") as f:
186
- f.write(mixture_metadata(mixdb, m_id))
200
+ f.write(mixture_metadata(mixdb, m_id, mixture))
187
201
 
188
202
 
189
203
  def from_mixture(
@@ -246,12 +260,13 @@ def to_target(entry: TargetRecord) -> Target:
246
260
  )
247
261
 
248
262
 
249
- def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT]) -> AudioT:
263
+ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT], use_cache: bool = True) -> AudioT:
250
264
  """Get the augmented target audio data for the given mixture record
251
265
 
252
266
  :param mixdb: Mixture database
253
267
  :param mixture: Mixture record
254
268
  :param targets_audio: List of augmented target audio data (one per target in the mixup)
269
+ :param use_cache: If true, use LRU caching
255
270
  :return: Sum of augmented target audio data
256
271
  """
257
272
  # Apply impulse responses to targets
@@ -265,7 +280,7 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[Aud
265
280
  ir_idx = mixture.targets[idx].augmentation.ir
266
281
  if ir_idx is not None:
267
282
  targets_ir.append(
268
- apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
283
+ apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx)), use_cache)) # pyright: ignore [reportArgumentType]
269
284
  )
270
285
  else:
271
286
  targets_ir.append(target)