sonusai-0.19.9-py3-none-any.whl → sonusai-0.20.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sonusai/calc_metric_spenh.py +265 -233
  2. sonusai/data/genmixdb.yml +4 -2
  3. sonusai/data/silero_vad_v5.1.jit +0 -0
  4. sonusai/data/silero_vad_v5.1.onnx +0 -0
  5. sonusai/doc/doc.py +14 -0
  6. sonusai/genft.py +1 -1
  7. sonusai/genmetrics.py +15 -18
  8. sonusai/genmix.py +1 -1
  9. sonusai/genmixdb.py +30 -52
  10. sonusai/ir_metric.py +555 -0
  11. sonusai/metrics_summary.py +322 -0
  12. sonusai/mixture/__init__.py +6 -2
  13. sonusai/mixture/audio.py +139 -15
  14. sonusai/mixture/augmentation.py +199 -84
  15. sonusai/mixture/config.py +9 -4
  16. sonusai/mixture/constants.py +0 -1
  17. sonusai/mixture/datatypes.py +19 -10
  18. sonusai/mixture/generation.py +52 -64
  19. sonusai/mixture/helpers.py +38 -26
  20. sonusai/mixture/ir_delay.py +63 -0
  21. sonusai/mixture/mixdb.py +190 -46
  22. sonusai/mixture/targets.py +3 -6
  23. sonusai/mixture/truth_functions/energy.py +9 -5
  24. sonusai/mixture/truth_functions/metrics.py +1 -1
  25. sonusai/mkwav.py +1 -1
  26. sonusai/onnx_predict.py +1 -1
  27. sonusai/queries/queries.py +1 -1
  28. sonusai/utils/__init__.py +2 -0
  29. sonusai/utils/asr.py +1 -1
  30. sonusai/utils/load_object.py +8 -2
  31. sonusai/utils/stratified_shuffle_split.py +1 -1
  32. sonusai/utils/temp_seed.py +13 -0
  33. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
  34. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
  35. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
  36. sonusai/mixture/soundfile_audio.py +0 -130
  37. sonusai/mixture/sox_audio.py +0 -476
  38. sonusai/mixture/sox_augmentation.py +0 -136
  39. sonusai/mixture/torchaudio_audio.py +0 -106
  40. sonusai/mixture/torchaudio_augmentation.py +0 -109
  41. {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -61,7 +61,7 @@ def db_connection(
61
61
  if not create and readonly:
62
62
  name += "?mode=ro"
63
63
 
64
- connection = sqlite3.connect("file:" + name, uri=True)
64
+ connection = sqlite3.connect("file:" + name, uri=True, timeout=20)
65
65
 
66
66
  if verbose:
67
67
  connection.set_trace_callback(print)
@@ -84,7 +84,7 @@ class SQLiteContextManager:
84
84
 
85
85
 
86
86
  class MixtureDatabase:
87
- def __init__(self, location: str, test: bool = False) -> None:
87
+ def __init__(self, location: str, test: bool = False, use_cache: bool = True) -> None:
88
88
  import json
89
89
  from os.path import exists
90
90
 
@@ -92,6 +92,7 @@ class MixtureDatabase:
92
92
 
93
93
  self.location = location
94
94
  self.test = test
95
+ self.use_cache = use_cache
95
96
 
96
97
  if not exists(db_file(self.location, self.test)):
97
98
  raise OSError(f"Could not find mixture database in {self.location}")
@@ -121,7 +122,7 @@ class MixtureDatabase:
121
122
  class_weights_threshold=self.class_weights_thresholds,
122
123
  feature=self.feature,
123
124
  impulse_response_files=self.impulse_response_files,
124
- mixtures=self.mixtures,
125
+ mixtures=self.mixtures(),
125
126
  noise_mix_mode=self.noise_mix_mode,
126
127
  noise_files=self.noise_files,
127
128
  num_classes=self.num_classes,
@@ -254,6 +255,16 @@ class MixtureDatabase:
254
255
  "Predicted rating of overall quality of mixture versus true targets",
255
256
  ),
256
257
  MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
258
+ MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
259
+ MetricDoc("Mixture Metrics", "mxmin", "Mixture min level"),
260
+ MetricDoc("Mixture Metrics", "mxmax", "Mixture max levl"),
261
+ MetricDoc("Mixture Metrics", "mxpkdb", "Mixture Pk lev dB"),
262
+ MetricDoc("Mixture Metrics", "mxlrms", "Mixture RMS lev dB"),
263
+ MetricDoc("Mixture Metrics", "mxpkr", "Mixture RMS Pk dB"),
264
+ MetricDoc("Mixture Metrics", "mxtr", "Mixture RMS Tr dB"),
265
+ MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
266
+ MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
267
+ MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
257
268
  MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
258
269
  MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
259
270
  MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
@@ -488,7 +499,7 @@ class MixtureDatabase:
488
499
  return truth_configs
489
500
 
490
501
  def target_truth_configs(self, t_id: int) -> TruthConfigs:
491
- return _target_truth_configs(self.db, t_id)
502
+ return _target_truth_configs(self.db, t_id, self.use_cache)
492
503
 
493
504
  @cached_property
494
505
  def random_snrs(self) -> list[float]:
@@ -556,7 +567,7 @@ class MixtureDatabase:
556
567
  :param sm_id: Spectral mask ID
557
568
  :return: Spectral mask
558
569
  """
559
- return _spectral_mask(self.db, sm_id)
570
+ return _spectral_mask(self.db, sm_id, self.use_cache)
560
571
 
561
572
  @cached_property
562
573
  def target_files(self) -> list[TargetFile]:
@@ -619,7 +630,7 @@ class MixtureDatabase:
619
630
  :param t_id: Target file ID
620
631
  :return: Target file
621
632
  """
622
- return _target_file(self.db, t_id)
633
+ return _target_file(self.db, t_id, self.use_cache)
623
634
 
624
635
  @cached_property
625
636
  def num_target_files(self) -> int:
@@ -657,7 +668,7 @@ class MixtureDatabase:
657
668
  :param n_id: Noise file ID
658
669
  :return: Noise file
659
670
  """
660
- return _noise_file(self.db, n_id)
671
+ return _noise_file(self.db, n_id, self.use_cache)
661
672
 
662
673
  @cached_property
663
674
  def num_noise_files(self) -> int:
@@ -680,7 +691,7 @@ class MixtureDatabase:
680
691
 
681
692
  with self.db() as c:
682
693
  return [
683
- ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
694
+ ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]), impulse_response[3])
684
695
  for impulse_response in c.execute(
685
696
  "SELECT impulse_response_file.* FROM impulse_response_file"
686
697
  ).fetchall()
@@ -699,14 +710,24 @@ class MixtureDatabase:
699
710
  ]
700
711
 
701
712
  def impulse_response_file(self, ir_id: int | None) -> str | None:
702
- """Get impulse response file with ID from db
713
+ """Get impulse response file name with ID from db
714
+
715
+ :param ir_id: Impulse response file ID
716
+ :return: Impulse response file name
717
+ """
718
+ if ir_id is None:
719
+ return None
720
+ return _impulse_response_file(self.db, ir_id, self.use_cache)
721
+
722
+ def impulse_response_delay(self, ir_id: int | None) -> int | None:
723
+ """Get impulse response delay with ID from db
703
724
 
704
725
  :param ir_id: Impulse response file ID
705
- :return: Noise
726
+ :return: Impulse response delay
706
727
  """
707
728
  if ir_id is None:
708
729
  return None
709
- return _impulse_response_file(self.db, ir_id)
730
+ return _impulse_response_delay(self.db, ir_id, self.use_cache)
710
731
 
711
732
  @cached_property
712
733
  def num_impulse_response_files(self) -> int:
@@ -717,7 +738,6 @@ class MixtureDatabase:
717
738
  with self.db() as c:
718
739
  return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
719
740
 
720
- @cached_property
721
741
  def mixtures(self) -> list[Mixture]:
722
742
  """Get mixtures from db
723
743
 
@@ -760,7 +780,7 @@ class MixtureDatabase:
760
780
  :param m_id: Zero-based mixture ID
761
781
  :return: Mixture record
762
782
  """
763
- return _mixture(self.db, m_id)
783
+ return _mixture(self.db, m_id, self.use_cache)
764
784
 
765
785
  @cached_property
766
786
  def mixid_width(self) -> int:
@@ -805,7 +825,7 @@ class MixtureDatabase:
805
825
  """
806
826
  from .audio import read_audio
807
827
 
808
- return read_audio(self.target_file(t_id).name)
828
+ return read_audio(self.target_file(t_id).name, self.use_cache)
809
829
 
810
830
  def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
811
831
  """Get augmented noise audio
@@ -814,18 +834,11 @@ class MixtureDatabase:
814
834
  :return: Augmented noise audio
815
835
  """
816
836
  from .audio import read_audio
817
- from .audio import read_ir
818
837
  from .augmentation import apply_augmentation
819
- from .augmentation import apply_impulse_response
820
838
 
821
839
  noise = self.noise_file(mixture.noise.file_id)
822
- audio = read_audio(noise.name)
823
- audio = apply_augmentation(audio, mixture.noise.augmentation)
824
- if mixture.noise.augmentation.ir is not None:
825
- audio = apply_impulse_response(
826
- audio,
827
- read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
828
- )
840
+ audio = read_audio(noise.name, self.use_cache)
841
+ audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
829
842
 
830
843
  return audio
831
844
 
@@ -859,8 +872,9 @@ class MixtureDatabase:
859
872
  for target in mixture.targets:
860
873
  target_audio = self.read_target_audio(target.file_id)
861
874
  target_audio = apply_augmentation(
875
+ mixdb=self,
862
876
  audio=target_audio,
863
- augmentation=target.augmentation,
877
+ augmentation=target.augmentation.pre,
864
878
  frame_length=self.feature_step_samples,
865
879
  )
866
880
  target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
@@ -1119,8 +1133,7 @@ class MixtureDatabase:
1119
1133
  offsets = range(0, mixture.samples, self.ft_config.overlap)
1120
1134
  if len(target_energy) != len(offsets):
1121
1135
  raise ValueError(
1122
- f"Number of frames in energy, {len(target_energy)},"
1123
- f" is not number of frames in mixture, {len(offsets)}"
1136
+ f"Number of frames in energy, {len(target_energy)}, is not number of frames in mixture, {len(offsets)}"
1124
1137
  )
1125
1138
 
1126
1139
  for idx, offset in enumerate(offsets):
@@ -1332,7 +1345,7 @@ class MixtureDatabase:
1332
1345
  return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
1333
1346
 
1334
1347
  def speaker(self, s_id: int | None, tier: str) -> str | None:
1335
- return _speaker(self.db, s_id, tier)
1348
+ return _speaker(self.db, s_id, tier, self.use_cache)
1336
1349
 
1337
1350
  def speech_metadata(self, tier: str) -> list[str]:
1338
1351
  from .helpers import get_textgrid_tier_from_target_file
@@ -1370,11 +1383,11 @@ class MixtureDatabase:
1370
1383
  # Check for tempo augmentation and adjust Interval start and end data as needed
1371
1384
  entries = []
1372
1385
  for entry in data:
1373
- if target.augmentation.tempo is not None:
1386
+ if target.augmentation.pre.tempo is not None:
1374
1387
  entries.append(
1375
1388
  Interval(
1376
- entry.start / target.augmentation.tempo,
1377
- entry.end / target.augmentation.tempo,
1389
+ entry.start / target.augmentation.pre.tempo,
1390
+ entry.end / target.augmentation.pre.tempo,
1378
1391
  entry.label,
1379
1392
  )
1380
1393
  )
@@ -1464,7 +1477,7 @@ class MixtureDatabase:
1464
1477
 
1465
1478
  return sorted(result)
1466
1479
 
1467
- def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> list[Any]:
1480
+ def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
1468
1481
  """Get metrics data for the given mixture ID
1469
1482
 
1470
1483
  :param m_id: Zero-based mixture ID
@@ -1595,6 +1608,19 @@ class MixtureDatabase:
1595
1608
 
1596
1609
  speech = create_speech()
1597
1610
 
1611
+ def create_mixture_stats() -> Callable[[], AudioStatsMetrics]:
1612
+ state: AudioStatsMetrics | None = None
1613
+
1614
+ def get() -> AudioStatsMetrics:
1615
+ nonlocal state
1616
+ if state is None:
1617
+ state = calc_audio_stats(mixture_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1618
+ return state
1619
+
1620
+ return get
1621
+
1622
+ mixture_stats = create_mixture_stats()
1623
+
1598
1624
  def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
1599
1625
  state: list[AudioStatsMetrics] | None = None
1600
1626
 
@@ -1803,6 +1829,36 @@ class MixtureDatabase:
1803
1829
  extended=False,
1804
1830
  )
1805
1831
 
1832
+ if m == "mxdco":
1833
+ return mixture_stats().dco
1834
+
1835
+ if m == "mxmin":
1836
+ return mixture_stats().min
1837
+
1838
+ if m == "mxmax":
1839
+ return mixture_stats().max
1840
+
1841
+ if m == "mxpkdb":
1842
+ return mixture_stats().pkdb
1843
+
1844
+ if m == "mxlrms":
1845
+ return mixture_stats().lrms
1846
+
1847
+ if m == "mxpkr":
1848
+ return mixture_stats().pkr
1849
+
1850
+ if m == "mxtr":
1851
+ return mixture_stats().tr
1852
+
1853
+ if m == "mxcr":
1854
+ return mixture_stats().cr
1855
+
1856
+ if m == "mxfl":
1857
+ return mixture_stats().fl
1858
+
1859
+ if m == "mxpkc":
1860
+ return mixture_stats().pkc
1861
+
1806
1862
  if m == "mxtdco":
1807
1863
  return target_stats().dco
1808
1864
 
@@ -1916,21 +1972,34 @@ class MixtureDatabase:
1916
1972
 
1917
1973
  raise AttributeError(f"Unrecognized metric: '{m}'")
1918
1974
 
1919
- result: list[float | int | str | Segsnr | None] = []
1975
+ result: dict[str, Any] = {}
1920
1976
  for metric in metrics:
1921
- result.append(calc(metric))
1977
+ result[metric] = calc(metric)
1978
+
1979
+ # Check for metrics dependencies and add them even if not explicitly requested.
1980
+ if metric.startswith("mxwer"):
1981
+ dependencies = ("mxasr." + metric[6:], "tasr." + metric[6:])
1982
+ for dependency in dependencies:
1983
+ result[dependency] = calc(dependency)
1922
1984
 
1923
1985
  return result
1924
1986
 
1925
1987
 
1926
- @lru_cache
1927
- def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1988
+ def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
1928
1989
  """Get spectral mask with ID from db
1929
1990
 
1930
1991
  :param db: Database context
1931
1992
  :param sm_id: Spectral mask ID
1993
+ :param use_cache: If true, use LRU caching
1932
1994
  :return: Spectral mask
1933
1995
  """
1996
+ if use_cache:
1997
+ return __spectral_mask(db, sm_id)
1998
+ return __spectral_mask.__wrapped__(db, sm_id)
1999
+
2000
+
2001
+ @lru_cache
2002
+ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1934
2003
  from .db_datatypes import SpectralMaskRecord
1935
2004
 
1936
2005
  with db() as c:
@@ -1953,12 +2022,26 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1953
2022
  )
1954
2023
 
1955
2024
 
2025
+ def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
2026
+ """Get target file with ID from db
2027
+
2028
+ :param db: Database context
2029
+ :param t_id: Target file ID
2030
+ :param use_cache: If true, use LRU caching
2031
+ :return: Target file
2032
+ """
2033
+ if use_cache:
2034
+ return __target_file(db, t_id, use_cache)
2035
+ return __target_file.__wrapped__(db, t_id, use_cache)
2036
+
2037
+
1956
2038
  @lru_cache
1957
- def _target_file(db: partial, t_id: int) -> TargetFile:
2039
+ def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
1958
2040
  """Get target file with ID from db
1959
2041
 
1960
2042
  :param db: Database context
1961
2043
  :param t_id: Target file ID
2044
+ :param use_cache: If true, use LRU caching
1962
2045
  :return: Target file
1963
2046
  """
1964
2047
  import json
@@ -1982,19 +2065,26 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1982
2065
  samples=target_file.samples,
1983
2066
  class_indices=json.loads(target_file.class_indices),
1984
2067
  level_type=target_file.level_type,
1985
- truth_configs=_target_truth_configs(db, t_id),
2068
+ truth_configs=_target_truth_configs(db, t_id, use_cache),
1986
2069
  speaker_id=target_file.speaker_id,
1987
2070
  )
1988
2071
 
1989
2072
 
1990
- @lru_cache
1991
- def _noise_file(db: partial, n_id: int) -> NoiseFile:
2073
+ def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
1992
2074
  """Get noise file with ID from db
1993
2075
 
1994
2076
  :param db: Database context
1995
2077
  :param n_id: Noise file ID
2078
+ :param use_cache: If true, use LRU caching
1996
2079
  :return: Noise file
1997
2080
  """
2081
+ if use_cache:
2082
+ return __noise_file(db, n_id)
2083
+ return __noise_file.__wrapped__(db, n_id)
2084
+
2085
+
2086
+ @lru_cache
2087
+ def __noise_file(db: partial, n_id: int) -> NoiseFile:
1998
2088
  with db() as c:
1999
2089
  noise = c.execute(
2000
2090
  """
@@ -2007,14 +2097,21 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
2007
2097
  return NoiseFile(name=noise[0], samples=noise[1])
2008
2098
 
2009
2099
 
2010
- @lru_cache
2011
- def _impulse_response_file(db: partial, ir_id: int) -> str:
2012
- """Get impulse response file with ID from db
2100
+ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2101
+ """Get impulse response file name with ID from db
2013
2102
 
2014
2103
  :param db: Database context
2015
2104
  :param ir_id: Impulse response file ID
2016
- :return: Noise
2105
+ :param use_cache: If true, use LRU caching
2106
+ :return: Impulse response file name
2017
2107
  """
2108
+ if use_cache:
2109
+ return __impulse_response_file(db, ir_id)
2110
+ return __impulse_response_file.__wrapped__(db, ir_id)
2111
+
2112
+
2113
+ @lru_cache
2114
+ def __impulse_response_file(db: partial, ir_id: int) -> str:
2018
2115
  with db() as c:
2019
2116
  return str(
2020
2117
  c.execute(
@@ -2028,14 +2125,49 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
2028
2125
  )
2029
2126
 
2030
2127
 
2128
+ def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
2129
+ """Get impulse response delay with ID from db
2130
+
2131
+ :param db: Database context
2132
+ :param ir_id: Impulse response file ID
2133
+ :param use_cache: If true, use LRU caching
2134
+ :return: Impulse response delay
2135
+ """
2136
+ if use_cache:
2137
+ return __impulse_response_delay(db, ir_id)
2138
+ return __impulse_response_delay.__wrapped__(db, ir_id)
2139
+
2140
+
2031
2141
  @lru_cache
2032
- def _mixture(db: partial, m_id: int) -> Mixture:
2142
+ def __impulse_response_delay(db: partial, ir_id: int) -> int:
2143
+ with db() as c:
2144
+ return int(
2145
+ c.execute(
2146
+ """
2147
+ SELECT impulse_response_file.delay
2148
+ FROM impulse_response_file
2149
+ WHERE ? = impulse_response_file.id
2150
+ """,
2151
+ (ir_id + 1,),
2152
+ ).fetchone()[0]
2153
+ )
2154
+
2155
+
2156
+ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
2033
2157
  """Get mixture record with ID from db
2034
2158
 
2035
2159
  :param db: Database context
2036
2160
  :param m_id: Zero-based mixture ID
2161
+ :param use_cache: If true, use LRU caching
2037
2162
  :return: Mixture record
2038
2163
  """
2164
+ if use_cache:
2165
+ return __mixture(db, m_id)
2166
+ return __mixture.__wrapped__(db, m_id)
2167
+
2168
+
2169
+ @lru_cache
2170
+ def __mixture(db: partial, m_id: int) -> Mixture:
2039
2171
  from .db_datatypes import MixtureRecord
2040
2172
  from .db_datatypes import TargetRecord
2041
2173
  from .helpers import to_mixture
@@ -2068,8 +2200,14 @@ def _mixture(db: partial, m_id: int) -> Mixture:
2068
2200
  return to_mixture(mixture, targets)
2069
2201
 
2070
2202
 
2203
+ def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
2204
+ if use_cache:
2205
+ return __speaker(db, s_id, tier)
2206
+ return __speaker.__wrapped__(db, s_id, tier)
2207
+
2208
+
2071
2209
  @lru_cache
2072
- def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2210
+ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2073
2211
  if s_id is None:
2074
2212
  return None
2075
2213
 
@@ -2082,8 +2220,14 @@ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2082
2220
  return data[0]
2083
2221
 
2084
2222
 
2223
+ def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
2224
+ if use_cache:
2225
+ return __target_truth_configs(db, t_id)
2226
+ return __target_truth_configs.__wrapped__(db, t_id)
2227
+
2228
+
2085
2229
  @lru_cache
2086
- def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2230
+ def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2087
2231
  import json
2088
2232
 
2089
2233
  from .datatypes import TruthConfig
sonusai/mixture/targets.py CHANGED
@@ -16,14 +16,11 @@ def get_augmented_targets(
16
16
 
17
17
  augmented_targets: list[AugmentedTarget] = []
18
18
  for mixup in mixups:
19
- augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
19
+ target_augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
20
20
  for target_index in range(len(target_files)):
21
- for augmentation_index in augmentation_indices:
21
+ for target_augmentation_index in target_augmentation_indices:
22
22
  augmented_targets.append(
23
- AugmentedTarget(
24
- target_id=target_index,
25
- target_augmentation_id=augmentation_index,
26
- )
23
+ AugmentedTarget(target_id=target_index, target_augmentation_id=target_augmentation_index)
27
24
  )
28
25
 
29
26
  return augmented_targets
sonusai/mixture/truth_functions/energy.py CHANGED
@@ -13,6 +13,7 @@ def _core(
13
13
  parameters: int,
14
14
  mapped: bool,
15
15
  snr: bool,
16
+ use_cache: bool = True,
16
17
  ) -> Truth:
17
18
  from os.path import join
18
19
 
@@ -50,8 +51,8 @@ def _core(
50
51
  tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
51
52
 
52
53
  if mapped:
53
- snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
54
- snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
54
+ snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]), use_cache)
55
+ snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]), use_cache)
55
56
  tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
56
57
 
57
58
  truth[frame] = tmp
@@ -85,7 +86,7 @@ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
85
86
  return ForwardTransform(**feature_forward_transform_config(feature)).bins
86
87
 
87
88
 
88
- def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
89
+ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
89
90
  """Frequency domain energy truth generation function
90
91
 
91
92
  Calculates the true energy per bin:
@@ -104,6 +105,7 @@ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
104
105
  parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
105
106
  mapped=False,
106
107
  snr=False,
108
+ use_cache=use_cache,
107
109
  )
108
110
 
109
111
 
@@ -118,7 +120,7 @@ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
118
120
  return ForwardTransform(**feature_forward_transform_config(feature)).bins
119
121
 
120
122
 
121
- def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
123
+ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
122
124
  """Frequency domain SNR truth function documentation
123
125
 
124
126
  Calculates the true SNR per bin:
@@ -137,6 +139,7 @@ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) ->
137
139
  parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
138
140
  mapped=False,
139
141
  snr=True,
142
+ use_cache=use_cache,
140
143
  )
141
144
 
142
145
 
@@ -156,7 +159,7 @@ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> i
156
159
  return ForwardTransform(**feature_forward_transform_config(feature)).bins
157
160
 
158
161
 
159
- def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
162
+ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
160
163
  """Frequency domain mapped SNR truth function documentation
161
164
 
162
165
  Output shape: [:, bins]
@@ -169,6 +172,7 @@ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: d
169
172
  parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
170
173
  mapped=True,
171
174
  snr=True,
175
+ use_cache=use_cache,
172
176
  )
173
177
 
174
178
 
sonusai/mixture/truth_functions/metrics.py CHANGED
@@ -25,4 +25,4 @@ def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
25
25
  m = [config["metric"]]
26
26
  else:
27
27
  m = config["metric"]
28
- return mixdb.mixture_metrics(m_id, m)[0][target_index]
28
+ return mixdb.mixture_metrics(m_id, m)[m[0]][target_index]
sonusai/mkwav.py CHANGED
@@ -63,7 +63,7 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
63
63
  if write_noise:
64
64
  write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
65
65
 
66
- write_mixture_metadata(mixdb, m_id)
66
+ write_mixture_metadata(mixdb, m_id=m_id)
67
67
 
68
68
 
69
69
  def main() -> None:
sonusai/onnx_predict.py CHANGED
@@ -193,7 +193,7 @@ def main() -> None:
193
193
  # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
194
194
  predict = session.run(out_names, {in0name: feature})[0]
195
195
  # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
196
- output_fname = join(output_dir, mixdb.mixtures[mixid].name)
196
+ output_fname = join(output_dir, mixdb.mixture(mixid).name)
197
197
  with h5py.File(output_fname, "a") as f:
198
198
  if "predict" in f:
199
199
  del f["predict"]
sonusai/queries/queries.py CHANGED
@@ -178,7 +178,7 @@ def get_mixids_from_snr(
178
178
  result: dict[float, list[int]] = {}
179
179
  for snr in snrs:
180
180
  # Get a list of mixids for each SNR
181
- result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
181
+ result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures()) if mixture.snr == snr and i in mixid_out])
182
182
 
183
183
  return result
184
184
 
sonusai/utils/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # SonusAI general utilities
2
2
  # ruff: noqa: F401
3
+
3
4
  from .asl_p56 import asl_p56
4
5
  from .asr import ASRResult
5
6
  from .asr import calc_asr
@@ -53,5 +54,6 @@ from .stacked_complex import stacked_complex_imag
53
54
  from .stacked_complex import stacked_complex_real
54
55
  from .stacked_complex import unstack_complex
55
56
  from .stratified_shuffle_split import stratified_shuffle_split_mixid
57
+ from .temp_seed import temp_seed
56
58
  from .write_audio import write_audio
57
59
  from .yes_or_no import yes_or_no
sonusai/utils/asr.py CHANGED
@@ -65,7 +65,7 @@ def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
65
65
  from sonusai.mixture import read_audio
66
66
 
67
67
  if not isinstance(audio, np.ndarray):
68
- audio = copy(read_audio(audio))
68
+ audio = copy(read_audio(audio, config.get("use_cache", True)))
69
69
 
70
70
  return _asr_fn(engine)(audio, **config)
71
71
 
sonusai/utils/load_object.py CHANGED
@@ -2,9 +2,15 @@ from functools import lru_cache
2
2
  from typing import Any
3
3
 
4
4
 
5
+ def load_object(name: str, use_cache: bool = True) -> Any:
6
+ """Load an object from a pickle file"""
7
+ if use_cache:
8
+ return _load_object(name)
9
+ return _load_object.__wrapped__(name)
10
+
11
+
5
12
  @lru_cache
6
- def load_object(name: str) -> Any:
7
- """Load an object from a pickle file (with LRU caching)"""
13
+ def _load_object(name: str) -> Any:
8
14
  import pickle
9
15
  from os.path import exists
10
16
 
sonusai/utils/stratified_shuffle_split.py CHANGED
@@ -42,7 +42,7 @@ def stratified_shuffle_split_mixid(
42
42
  raise ValueError("vsplit must be between 0 and 1")
43
43
 
44
44
  a_class_mixid: dict[int, list[int]] = {i + 1: [] for i in range(mixdb.num_classes)}
45
- for mixid, mixture in enumerate(mixdb.mixtures):
45
+ for mixid, mixture in enumerate(mixdb.mixtures()):
46
46
  class_count = get_class_count_from_mixids(mixdb, mixid)
47
47
  if any(class_count):
48
48
  for class_index in mixdb.target_files[mixture.targets[0].file_id].class_indices:
sonusai/utils/temp_seed.py ADDED
@@ -0,0 +1,13 @@
1
+ import contextlib
2
+
3
+ import numpy as np
4
+
5
+
6
+ @contextlib.contextmanager
7
+ def temp_seed(seed):
8
+ state = np.random.get_state()
9
+ np.random.seed(seed)
10
+ try:
11
+ yield
12
+ finally:
13
+ np.random.set_state(state)
{sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: sonusai
3
- Version: 0.19.9
3
+ Version: 0.20.2
4
4
  Summary: Framework for building deep neural network models for sound, speech, and voice AI
5
5
  Home-page: https://aaware.com
6
6
  License: GPL-3.0-only