sonusai 0.19.9__py3-none-any.whl → 0.19.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/audio.py CHANGED
@@ -44,44 +44,77 @@ def validate_input_file(input_filepath: str | Path) -> None:
44
44
  raise OSError(f"This installation cannot process .{ext} files")
45
45
 
46
46
 
47
- @lru_cache
48
- def get_sample_rate(name: str | Path) -> int:
47
+ def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
49
48
  """Get sample rate from audio file
50
49
 
51
50
  :param name: File name
51
+ :param use_cache: If true, use LRU caching
52
52
  :return: Sample rate
53
53
  """
54
+ if use_cache:
55
+ return _get_sample_rate(name)
56
+ return _get_sample_rate.__wrapped__(name)
57
+
58
+
59
+ @lru_cache
60
+ def _get_sample_rate(name: str | Path) -> int:
54
61
  from .soundfile_audio import get_sample_rate
55
62
 
56
63
  return get_sample_rate(name)
57
64
 
58
65
 
59
- @lru_cache
60
- def read_audio(name: str | Path) -> AudioT:
66
+ def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
61
67
  """Read audio data from a file
62
68
 
63
69
  :param name: File name
70
+ :param use_cache: If true, use LRU caching
64
71
  :return: Array of time domain audio data
65
72
  """
73
+ if use_cache:
74
+ return _read_audio(name)
75
+ return _read_audio.__wrapped__(name)
76
+
77
+
78
+ @lru_cache
79
+ def _read_audio(name: str | Path) -> AudioT:
66
80
  from .soundfile_audio import read_audio
67
81
 
68
82
  return read_audio(name)
69
83
 
70
84
 
71
- @lru_cache
72
- def read_ir(name: str | Path) -> ImpulseResponseData:
85
+ def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
73
86
  """Read impulse response data
74
87
 
75
88
  :param name: File name
89
+ :param use_cache: If true, use LRU caching
76
90
  :return: ImpulseResponseData object
77
91
  """
92
+ if use_cache:
93
+ return _read_ir(name)
94
+ return _read_ir.__wrapped__(name)
95
+
96
+
97
+ @lru_cache
98
+ def _read_ir(name: str | Path) -> ImpulseResponseData:
78
99
  from .soundfile_audio import read_ir
79
100
 
80
101
  return read_ir(name)
81
102
 
82
103
 
104
+ def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
105
+ """Get the number of samples resampled to the SonusAI sample rate in the given file
106
+
107
+ :param name: File name
108
+ :param use_cache: If true, use LRU caching
109
+ :return: number of samples in resampled audio
110
+ """
111
+ if use_cache:
112
+ return _get_num_samples(name)
113
+ return _get_num_samples.__wrapped__(name)
114
+
115
+
83
116
  @lru_cache
84
- def get_num_samples(name: str | Path) -> int:
117
+ def _get_num_samples(name: str | Path) -> int:
85
118
  """Get the number of samples resampled to the SonusAI sample rate in the given file
86
119
 
87
120
  :param name: File name
@@ -119,8 +119,7 @@ def initialize_db(location: str, test: bool = False) -> None:
119
119
  id INTEGER PRIMARY KEY NOT NULL,
120
120
  file_id INTEGER NOT NULL,
121
121
  augmentation TEXT NOT NULL,
122
- FOREIGN KEY(file_id) REFERENCES target_file (id),
123
- UNIQUE(file_id, augmentation))
122
+ FOREIGN KEY(file_id) REFERENCES target_file (id))
124
123
  """)
125
124
 
126
125
  con.execute("""
@@ -389,8 +388,7 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
389
388
  con.close()
390
389
 
391
390
 
392
- def populate_mixture_table(
393
- location: str,
391
+ def generate_mixtures(
394
392
  noise_mix_mode: str,
395
393
  augmented_targets: list[AugmentedTarget],
396
394
  target_files: list[TargetFile],
@@ -403,13 +401,8 @@ def populate_mixture_table(
403
401
  num_classes: int,
404
402
  feature_step_samples: int,
405
403
  num_ir: int,
406
- test: bool = False,
407
- ) -> tuple[int, int]:
408
- """Generate mixtures and populate mixture table"""
409
- from .helpers import from_mixture
410
- from .helpers import from_target
411
- from .mixdb import db_connection
412
-
404
+ ) -> tuple[int, int, list[Mixture]]:
405
+ """Generate mixtures"""
413
406
  if noise_mix_mode == "exhaustive":
414
407
  func = _exhaustive_noise_mix
415
408
  elif noise_mix_mode == "non-exhaustive":
@@ -419,7 +412,7 @@ def populate_mixture_table(
419
412
  else:
420
413
  raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
421
414
 
422
- used_noise_files, used_noise_samples, mixtures = func(
415
+ return func(
423
416
  augmented_targets=augmented_targets,
424
417
  target_files=target_files,
425
418
  target_augmentations=target_augmentations,
@@ -433,20 +426,41 @@ def populate_mixture_table(
433
426
  num_ir=num_ir,
434
427
  )
435
428
 
429
+
430
+ def populate_mixture_table(
431
+ location: str,
432
+ mixtures: list[Mixture],
433
+ test: bool = False,
434
+ logging: bool = False,
435
+ show_progress: bool = False,
436
+ ) -> None:
437
+ """Populate mixture table"""
438
+ from sonusai import logger
439
+ from sonusai.utils import track
440
+
441
+ from .helpers import from_mixture
442
+ from .helpers import from_target
443
+ from .mixdb import db_connection
444
+
436
445
  con = db_connection(location=location, readonly=False, test=test)
446
+
437
447
  # Populate target table
448
+ if logging:
449
+ logger.info("Populating target table")
450
+ targets: list[tuple[int, str]] = []
438
451
  for mixture in mixtures:
439
452
  for target in mixture.targets:
440
- con.execute(
441
- """
442
- INSERT OR IGNORE INTO target (file_id, augmentation)
443
- VALUES (?, ?)
444
- """,
445
- from_target(target),
446
- )
453
+ entry = from_target(target)
454
+ if entry not in targets:
455
+ targets.append(entry)
456
+ for target in track(targets, disable=not show_progress):
457
+ con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
447
458
 
448
459
  # Populate mixture table
449
- for m_id, mixture in enumerate(mixtures):
460
+ if logging:
461
+ logger.info("Populating mixture table")
462
+ for mixture in track(mixtures, disable=not show_progress):
463
+ m_id = int(mixture.name)
450
464
  con.execute(
451
465
  """
452
466
  INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
@@ -473,20 +487,13 @@ def populate_mixture_table(
473
487
  con.commit()
474
488
  con.close()
475
489
 
476
- return used_noise_files, used_noise_samples
477
490
 
478
-
479
- def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData:
491
+ def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
480
492
  """Update mixture record with name and gains"""
481
493
  from .audio import get_next_noise
482
494
  from .augmentation import apply_gain
483
- from .datatypes import GenMixData
484
- from .helpers import from_mixture
485
495
  from .helpers import get_target
486
- from .mixdb import db_connection
487
496
 
488
- mixdb = MixtureDatabase(location, test)
489
- mixture = mixdb.mixture(m_id)
490
497
  mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
491
498
 
492
499
  noise_audio = _augmented_noise_audio(mixdb, mixture)
@@ -501,29 +508,8 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
501
508
 
502
509
  mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
503
510
 
504
- con = db_connection(location=location, readonly=False, test=test)
505
- con.execute(
506
- """
507
- UPDATE mixture SET name=?,
508
- noise_file_id=?,
509
- noise_augmentation=?,
510
- noise_offset=?,
511
- noise_snr_gain=?,
512
- random_snr=?,
513
- snr=?,
514
- samples=?,
515
- spectral_mask_id=?,
516
- spectral_mask_seed=?,
517
- target_snr_gain=?
518
- WHERE ? = mixture.id
519
- """,
520
- (*from_mixture(mixture), m_id + 1),
521
- )
522
- con.commit()
523
- con.close()
524
-
525
511
  if not with_data:
526
- return GenMixData()
512
+ return mixture, GenMixData()
527
513
 
528
514
  # Apply SNR gains
529
515
  targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
@@ -533,7 +519,7 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
533
519
  target_audio = get_target(mixdb, mixture, targets_audio)
534
520
  mixture_audio = target_audio + noise_audio
535
521
 
536
- return GenMixData(
522
+ return mixture, GenMixData(
537
523
  mixture=mixture_audio,
538
524
  targets=targets_audio,
539
525
  target=target_audio,
@@ -553,7 +539,7 @@ def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
553
539
  audio = read_audio(noise.name)
554
540
  audio = apply_augmentation(audio, noise_augmentation)
555
541
  if noise_augmentation.ir is not None:
556
- audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir)))
542
+ audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir))) # pyright: ignore [reportArgumentType]
557
543
 
558
544
  return audio
559
545
 
@@ -582,7 +568,10 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
582
568
 
583
569
 
584
570
  def _initialize_mixture_gains(
585
- mixdb: MixtureDatabase, mixture: Mixture, target_audio: AudioT, noise_audio: AudioT
571
+ mixdb: MixtureDatabase,
572
+ mixture: Mixture,
573
+ target_audio: AudioT,
574
+ noise_audio: AudioT,
586
575
  ) -> Mixture:
587
576
  import numpy as np
588
577
 
@@ -135,14 +135,20 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
135
135
  return results
136
136
 
137
137
 
138
- def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
138
+ def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> str:
139
139
  """Create a string of metadata for a Mixture
140
140
 
141
141
  :param mixdb: Mixture database
142
142
  :param m_id: Mixture ID
143
+ :param mixture: Mixture record
143
144
  :return: String of metadata
144
145
  """
145
- mixture = mixdb.mixture(m_id)
146
+ if m_id is not None:
147
+ mixture = mixdb.mixture(m_id)
148
+
149
+ if mixture is None:
150
+ raise ValueError("No mixture specified.")
151
+
146
152
  metadata = ""
147
153
  speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
148
154
  for mi, target in enumerate(mixture.targets):
@@ -173,17 +179,25 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
173
179
  return metadata
174
180
 
175
181
 
176
- def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> None:
182
+ def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> None:
177
183
  """Write mixture metadata to a text file
178
184
 
179
185
  :param mixdb: Mixture database
180
186
  :param m_id: Mixture ID
187
+ :param mixture: Mixture record
181
188
  """
182
189
  from os.path import join
183
190
 
184
- name = join(mixdb.location, "mixture", mixdb.mixture(m_id).name, "metadata.txt")
191
+ if m_id is not None:
192
+ name = mixdb.mixture(m_id).name
193
+ elif mixture is not None:
194
+ name = mixture.name
195
+ else:
196
+ raise ValueError("No mixture specified.")
197
+
198
+ name = join(mixdb.location, "mixture", name, "metadata.txt")
185
199
  with open(file=name, mode="w") as f:
186
- f.write(mixture_metadata(mixdb, m_id))
200
+ f.write(mixture_metadata(mixdb, m_id, mixture))
187
201
 
188
202
 
189
203
  def from_mixture(
@@ -246,12 +260,13 @@ def to_target(entry: TargetRecord) -> Target:
246
260
  )
247
261
 
248
262
 
249
- def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT]) -> AudioT:
263
+ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT], use_cache: bool = True) -> AudioT:
250
264
  """Get the augmented target audio data for the given mixture record
251
265
 
252
266
  :param mixdb: Mixture database
253
267
  :param mixture: Mixture record
254
268
  :param targets_audio: List of augmented target audio data (one per target in the mixup)
269
+ :param use_cache: If true, use LRU caching
255
270
  :return: Sum of augmented target audio data
256
271
  """
257
272
  # Apply impulse responses to targets
@@ -265,7 +280,7 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[Aud
265
280
  ir_idx = mixture.targets[idx].augmentation.ir
266
281
  if ir_idx is not None:
267
282
  targets_ir.append(
268
- apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
283
+ apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx)), use_cache)) # pyright: ignore [reportArgumentType]
269
284
  )
270
285
  else:
271
286
  targets_ir.append(target)
sonusai/mixture/mixdb.py CHANGED
@@ -61,7 +61,7 @@ def db_connection(
61
61
  if not create and readonly:
62
62
  name += "?mode=ro"
63
63
 
64
- connection = sqlite3.connect("file:" + name, uri=True)
64
+ connection = sqlite3.connect("file:" + name, uri=True, timeout=20)
65
65
 
66
66
  if verbose:
67
67
  connection.set_trace_callback(print)
@@ -84,7 +84,7 @@ class SQLiteContextManager:
84
84
 
85
85
 
86
86
  class MixtureDatabase:
87
- def __init__(self, location: str, test: bool = False) -> None:
87
+ def __init__(self, location: str, test: bool = False, use_cache: bool = True) -> None:
88
88
  import json
89
89
  from os.path import exists
90
90
 
@@ -92,6 +92,7 @@ class MixtureDatabase:
92
92
 
93
93
  self.location = location
94
94
  self.test = test
95
+ self.use_cache = use_cache
95
96
 
96
97
  if not exists(db_file(self.location, self.test)):
97
98
  raise OSError(f"Could not find mixture database in {self.location}")
@@ -121,7 +122,7 @@ class MixtureDatabase:
121
122
  class_weights_threshold=self.class_weights_thresholds,
122
123
  feature=self.feature,
123
124
  impulse_response_files=self.impulse_response_files,
124
- mixtures=self.mixtures,
125
+ mixtures=self.mixtures(),
125
126
  noise_mix_mode=self.noise_mix_mode,
126
127
  noise_files=self.noise_files,
127
128
  num_classes=self.num_classes,
@@ -488,7 +489,7 @@ class MixtureDatabase:
488
489
  return truth_configs
489
490
 
490
491
  def target_truth_configs(self, t_id: int) -> TruthConfigs:
491
- return _target_truth_configs(self.db, t_id)
492
+ return _target_truth_configs(self.db, t_id, self.use_cache)
492
493
 
493
494
  @cached_property
494
495
  def random_snrs(self) -> list[float]:
@@ -556,7 +557,7 @@ class MixtureDatabase:
556
557
  :param sm_id: Spectral mask ID
557
558
  :return: Spectral mask
558
559
  """
559
- return _spectral_mask(self.db, sm_id)
560
+ return _spectral_mask(self.db, sm_id, self.use_cache)
560
561
 
561
562
  @cached_property
562
563
  def target_files(self) -> list[TargetFile]:
@@ -619,7 +620,7 @@ class MixtureDatabase:
619
620
  :param t_id: Target file ID
620
621
  :return: Target file
621
622
  """
622
- return _target_file(self.db, t_id)
623
+ return _target_file(self.db, t_id, self.use_cache)
623
624
 
624
625
  @cached_property
625
626
  def num_target_files(self) -> int:
@@ -657,7 +658,7 @@ class MixtureDatabase:
657
658
  :param n_id: Noise file ID
658
659
  :return: Noise file
659
660
  """
660
- return _noise_file(self.db, n_id)
661
+ return _noise_file(self.db, n_id, self.use_cache)
661
662
 
662
663
  @cached_property
663
664
  def num_noise_files(self) -> int:
@@ -706,7 +707,7 @@ class MixtureDatabase:
706
707
  """
707
708
  if ir_id is None:
708
709
  return None
709
- return _impulse_response_file(self.db, ir_id)
710
+ return _impulse_response_file(self.db, ir_id, self.use_cache)
710
711
 
711
712
  @cached_property
712
713
  def num_impulse_response_files(self) -> int:
@@ -717,7 +718,6 @@ class MixtureDatabase:
717
718
  with self.db() as c:
718
719
  return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
719
720
 
720
- @cached_property
721
721
  def mixtures(self) -> list[Mixture]:
722
722
  """Get mixtures from db
723
723
 
@@ -760,7 +760,7 @@ class MixtureDatabase:
760
760
  :param m_id: Zero-based mixture ID
761
761
  :return: Mixture record
762
762
  """
763
- return _mixture(self.db, m_id)
763
+ return _mixture(self.db, m_id, self.use_cache)
764
764
 
765
765
  @cached_property
766
766
  def mixid_width(self) -> int:
@@ -805,7 +805,7 @@ class MixtureDatabase:
805
805
  """
806
806
  from .audio import read_audio
807
807
 
808
- return read_audio(self.target_file(t_id).name)
808
+ return read_audio(self.target_file(t_id).name, self.use_cache)
809
809
 
810
810
  def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
811
811
  """Get augmented noise audio
@@ -819,12 +819,12 @@ class MixtureDatabase:
819
819
  from .augmentation import apply_impulse_response
820
820
 
821
821
  noise = self.noise_file(mixture.noise.file_id)
822
- audio = read_audio(noise.name)
822
+ audio = read_audio(noise.name, self.use_cache)
823
823
  audio = apply_augmentation(audio, mixture.noise.augmentation)
824
824
  if mixture.noise.augmentation.ir is not None:
825
825
  audio = apply_impulse_response(
826
826
  audio,
827
- read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
827
+ read_ir(self.impulse_response_file(mixture.noise.augmentation.ir), self.use_cache), # pyright: ignore [reportArgumentType]
828
828
  )
829
829
 
830
830
  return audio
@@ -1332,7 +1332,7 @@ class MixtureDatabase:
1332
1332
  return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
1333
1333
 
1334
1334
  def speaker(self, s_id: int | None, tier: str) -> str | None:
1335
- return _speaker(self.db, s_id, tier)
1335
+ return _speaker(self.db, s_id, tier, self.use_cache)
1336
1336
 
1337
1337
  def speech_metadata(self, tier: str) -> list[str]:
1338
1338
  from .helpers import get_textgrid_tier_from_target_file
@@ -1464,7 +1464,7 @@ class MixtureDatabase:
1464
1464
 
1465
1465
  return sorted(result)
1466
1466
 
1467
- def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> list[Any]:
1467
+ def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
1468
1468
  """Get metrics data for the given mixture ID
1469
1469
 
1470
1470
  :param m_id: Zero-based mixture ID
@@ -1916,21 +1916,34 @@ class MixtureDatabase:
1916
1916
 
1917
1917
  raise AttributeError(f"Unrecognized metric: '{m}'")
1918
1918
 
1919
- result: list[float | int | str | Segsnr | None] = []
1919
+ result: dict[str, Any] = {}
1920
1920
  for metric in metrics:
1921
- result.append(calc(metric))
1921
+ result[metric] = calc(metric)
1922
+
1923
+ # Check for metrics dependencies and add them even if not explicitly requested.
1924
+ if metric.startswith("mxwer"):
1925
+ dependencies = ("mxasr." + metric[6:], "tasr." + metric[6:])
1926
+ for dependency in dependencies:
1927
+ result[dependency] = calc(dependency)
1922
1928
 
1923
1929
  return result
1924
1930
 
1925
1931
 
1926
- @lru_cache
1927
- def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1932
+ def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
1928
1933
  """Get spectral mask with ID from db
1929
1934
 
1930
1935
  :param db: Database context
1931
1936
  :param sm_id: Spectral mask ID
1937
+ :param use_cache: If true, use LRU caching
1932
1938
  :return: Spectral mask
1933
1939
  """
1940
+ if use_cache:
1941
+ return __spectral_mask(db, sm_id)
1942
+ return __spectral_mask.__wrapped__(db, sm_id)
1943
+
1944
+
1945
+ @lru_cache
1946
+ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1934
1947
  from .db_datatypes import SpectralMaskRecord
1935
1948
 
1936
1949
  with db() as c:
@@ -1953,12 +1966,26 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1953
1966
  )
1954
1967
 
1955
1968
 
1969
+ def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
1970
+ """Get target file with ID from db
1971
+
1972
+ :param db: Database context
1973
+ :param t_id: Target file ID
1974
+ :param use_cache: If true, use LRU caching
1975
+ :return: Target file
1976
+ """
1977
+ if use_cache:
1978
+ return __target_file(db, t_id, use_cache)
1979
+ return __target_file.__wrapped__(db, t_id, use_cache)
1980
+
1981
+
1956
1982
  @lru_cache
1957
- def _target_file(db: partial, t_id: int) -> TargetFile:
1983
+ def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
1958
1984
  """Get target file with ID from db
1959
1985
 
1960
1986
  :param db: Database context
1961
1987
  :param t_id: Target file ID
1988
+ :param use_cache: If true, use LRU caching
1962
1989
  :return: Target file
1963
1990
  """
1964
1991
  import json
@@ -1982,19 +2009,26 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1982
2009
  samples=target_file.samples,
1983
2010
  class_indices=json.loads(target_file.class_indices),
1984
2011
  level_type=target_file.level_type,
1985
- truth_configs=_target_truth_configs(db, t_id),
2012
+ truth_configs=_target_truth_configs(db, t_id, use_cache),
1986
2013
  speaker_id=target_file.speaker_id,
1987
2014
  )
1988
2015
 
1989
2016
 
1990
- @lru_cache
1991
- def _noise_file(db: partial, n_id: int) -> NoiseFile:
2017
+ def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
1992
2018
  """Get noise file with ID from db
1993
2019
 
1994
2020
  :param db: Database context
1995
2021
  :param n_id: Noise file ID
2022
+ :param use_cache: If true, use LRU caching
1996
2023
  :return: Noise file
1997
2024
  """
2025
+ if use_cache:
2026
+ return __noise_file(db, n_id)
2027
+ return __noise_file.__wrapped__(db, n_id)
2028
+
2029
+
2030
+ @lru_cache
2031
+ def __noise_file(db: partial, n_id: int) -> NoiseFile:
1998
2032
  with db() as c:
1999
2033
  noise = c.execute(
2000
2034
  """
@@ -2007,14 +2041,21 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
2007
2041
  return NoiseFile(name=noise[0], samples=noise[1])
2008
2042
 
2009
2043
 
2010
- @lru_cache
2011
- def _impulse_response_file(db: partial, ir_id: int) -> str:
2044
+ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2012
2045
  """Get impulse response file with ID from db
2013
2046
 
2014
2047
  :param db: Database context
2015
2048
  :param ir_id: Impulse response file ID
2016
- :return: Noise
2049
+ :param use_cache: If true, use LRU caching
2050
+ :return: Impulse response
2017
2051
  """
2052
+ if use_cache:
2053
+ return __impulse_response_file(db, ir_id)
2054
+ return __impulse_response_file.__wrapped__(db, ir_id)
2055
+
2056
+
2057
+ @lru_cache
2058
+ def __impulse_response_file(db: partial, ir_id: int) -> str:
2018
2059
  with db() as c:
2019
2060
  return str(
2020
2061
  c.execute(
@@ -2028,14 +2069,21 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
2028
2069
  )
2029
2070
 
2030
2071
 
2031
- @lru_cache
2032
- def _mixture(db: partial, m_id: int) -> Mixture:
2072
+ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
2033
2073
  """Get mixture record with ID from db
2034
2074
 
2035
2075
  :param db: Database context
2036
2076
  :param m_id: Zero-based mixture ID
2077
+ :param use_cache: If true, use LRU caching
2037
2078
  :return: Mixture record
2038
2079
  """
2080
+ if use_cache:
2081
+ return __mixture(db, m_id)
2082
+ return __mixture.__wrapped__(db, m_id)
2083
+
2084
+
2085
+ @lru_cache
2086
+ def __mixture(db: partial, m_id: int) -> Mixture:
2039
2087
  from .db_datatypes import MixtureRecord
2040
2088
  from .db_datatypes import TargetRecord
2041
2089
  from .helpers import to_mixture
@@ -2068,8 +2116,14 @@ def _mixture(db: partial, m_id: int) -> Mixture:
2068
2116
  return to_mixture(mixture, targets)
2069
2117
 
2070
2118
 
2119
+ def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
2120
+ if use_cache:
2121
+ return __speaker(db, s_id, tier)
2122
+ return __speaker.__wrapped__(db, s_id, tier)
2123
+
2124
+
2071
2125
  @lru_cache
2072
- def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2126
+ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2073
2127
  if s_id is None:
2074
2128
  return None
2075
2129
 
@@ -2082,8 +2136,14 @@ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2082
2136
  return data[0]
2083
2137
 
2084
2138
 
2139
+ def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
2140
+ if use_cache:
2141
+ return __target_truth_configs(db, t_id)
2142
+ return __target_truth_configs.__wrapped__(db, t_id)
2143
+
2144
+
2085
2145
  @lru_cache
2086
- def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2146
+ def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2087
2147
  import json
2088
2148
 
2089
2149
  from .datatypes import TruthConfig