sonusai 0.19.9__py3-none-any.whl → 0.19.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/metrics_summary.py +320 -0
- sonusai/mixture/__init__.py +2 -1
- sonusai/mixture/audio.py +40 -7
- sonusai/mixture/generation.py +42 -53
- sonusai/mixture/helpers.py +22 -7
- sonusai/mixture/mixdb.py +90 -30
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/METADATA +1 -1
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/RECORD +25 -22
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/WHEEL +0 -0
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/entry_points.txt +0 -0
sonusai/mixture/audio.py
CHANGED
@@ -44,44 +44,77 @@ def validate_input_file(input_filepath: str | Path) -> None:
|
|
44
44
|
raise OSError(f"This installation cannot process .{ext} files")
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
def get_sample_rate(name: str | Path) -> int:
|
47
|
+
def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
|
49
48
|
"""Get sample rate from audio file
|
50
49
|
|
51
50
|
:param name: File name
|
51
|
+
:param use_cache: If true, use LRU caching
|
52
52
|
:return: Sample rate
|
53
53
|
"""
|
54
|
+
if use_cache:
|
55
|
+
return _get_sample_rate(name)
|
56
|
+
return _get_sample_rate.__wrapped__(name)
|
57
|
+
|
58
|
+
|
59
|
+
@lru_cache
|
60
|
+
def _get_sample_rate(name: str | Path) -> int:
|
54
61
|
from .soundfile_audio import get_sample_rate
|
55
62
|
|
56
63
|
return get_sample_rate(name)
|
57
64
|
|
58
65
|
|
59
|
-
|
60
|
-
def read_audio(name: str | Path) -> AudioT:
|
66
|
+
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
|
61
67
|
"""Read audio data from a file
|
62
68
|
|
63
69
|
:param name: File name
|
70
|
+
:param use_cache: If true, use LRU caching
|
64
71
|
:return: Array of time domain audio data
|
65
72
|
"""
|
73
|
+
if use_cache:
|
74
|
+
return _read_audio(name)
|
75
|
+
return _read_audio.__wrapped__(name)
|
76
|
+
|
77
|
+
|
78
|
+
@lru_cache
|
79
|
+
def _read_audio(name: str | Path) -> AudioT:
|
66
80
|
from .soundfile_audio import read_audio
|
67
81
|
|
68
82
|
return read_audio(name)
|
69
83
|
|
70
84
|
|
71
|
-
|
72
|
-
def read_ir(name: str | Path) -> ImpulseResponseData:
|
85
|
+
def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
|
73
86
|
"""Read impulse response data
|
74
87
|
|
75
88
|
:param name: File name
|
89
|
+
:param use_cache: If true, use LRU caching
|
76
90
|
:return: ImpulseResponseData object
|
77
91
|
"""
|
92
|
+
if use_cache:
|
93
|
+
return _read_ir(name)
|
94
|
+
return _read_ir.__wrapped__(name)
|
95
|
+
|
96
|
+
|
97
|
+
@lru_cache
|
98
|
+
def _read_ir(name: str | Path) -> ImpulseResponseData:
|
78
99
|
from .soundfile_audio import read_ir
|
79
100
|
|
80
101
|
return read_ir(name)
|
81
102
|
|
82
103
|
|
104
|
+
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
|
105
|
+
"""Get the number of samples resampled to the SonusAI sample rate in the given file
|
106
|
+
|
107
|
+
:param name: File name
|
108
|
+
:param use_cache: If true, use LRU caching
|
109
|
+
:return: number of samples in resampled audio
|
110
|
+
"""
|
111
|
+
if use_cache:
|
112
|
+
return _get_num_samples(name)
|
113
|
+
return _get_num_samples.__wrapped__(name)
|
114
|
+
|
115
|
+
|
83
116
|
@lru_cache
|
84
|
-
def get_num_samples(name: str | Path) -> int:
|
117
|
+
def _get_num_samples(name: str | Path) -> int:
|
85
118
|
"""Get the number of samples resampled to the SonusAI sample rate in the given file
|
86
119
|
|
87
120
|
:param name: File name
|
sonusai/mixture/generation.py
CHANGED
@@ -119,8 +119,7 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
119
119
|
id INTEGER PRIMARY KEY NOT NULL,
|
120
120
|
file_id INTEGER NOT NULL,
|
121
121
|
augmentation TEXT NOT NULL,
|
122
|
-
FOREIGN KEY(file_id) REFERENCES target_file (id)
|
123
|
-
UNIQUE(file_id, augmentation))
|
122
|
+
FOREIGN KEY(file_id) REFERENCES target_file (id))
|
124
123
|
""")
|
125
124
|
|
126
125
|
con.execute("""
|
@@ -389,8 +388,7 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
|
|
389
388
|
con.close()
|
390
389
|
|
391
390
|
|
392
|
-
def populate_mixture_table(
|
393
|
-
location: str,
|
391
|
+
def generate_mixtures(
|
394
392
|
noise_mix_mode: str,
|
395
393
|
augmented_targets: list[AugmentedTarget],
|
396
394
|
target_files: list[TargetFile],
|
@@ -403,13 +401,8 @@ def populate_mixture_table(
|
|
403
401
|
num_classes: int,
|
404
402
|
feature_step_samples: int,
|
405
403
|
num_ir: int,
|
406
|
-
|
407
|
-
|
408
|
-
"""Generate mixtures and populate mixture table"""
|
409
|
-
from .helpers import from_mixture
|
410
|
-
from .helpers import from_target
|
411
|
-
from .mixdb import db_connection
|
412
|
-
|
404
|
+
) -> tuple[int, int, list[Mixture]]:
|
405
|
+
"""Generate mixtures"""
|
413
406
|
if noise_mix_mode == "exhaustive":
|
414
407
|
func = _exhaustive_noise_mix
|
415
408
|
elif noise_mix_mode == "non-exhaustive":
|
@@ -419,7 +412,7 @@ def populate_mixture_table(
|
|
419
412
|
else:
|
420
413
|
raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
|
421
414
|
|
422
|
-
|
415
|
+
return func(
|
423
416
|
augmented_targets=augmented_targets,
|
424
417
|
target_files=target_files,
|
425
418
|
target_augmentations=target_augmentations,
|
@@ -433,20 +426,41 @@ def populate_mixture_table(
|
|
433
426
|
num_ir=num_ir,
|
434
427
|
)
|
435
428
|
|
429
|
+
|
430
|
+
def populate_mixture_table(
|
431
|
+
location: str,
|
432
|
+
mixtures: list[Mixture],
|
433
|
+
test: bool = False,
|
434
|
+
logging: bool = False,
|
435
|
+
show_progress: bool = False,
|
436
|
+
) -> None:
|
437
|
+
"""Populate mixture table"""
|
438
|
+
from sonusai import logger
|
439
|
+
from sonusai.utils import track
|
440
|
+
|
441
|
+
from .helpers import from_mixture
|
442
|
+
from .helpers import from_target
|
443
|
+
from .mixdb import db_connection
|
444
|
+
|
436
445
|
con = db_connection(location=location, readonly=False, test=test)
|
446
|
+
|
437
447
|
# Populate target table
|
448
|
+
if logging:
|
449
|
+
logger.info("Populating target table")
|
450
|
+
targets: list[tuple[int, str]] = []
|
438
451
|
for mixture in mixtures:
|
439
452
|
for target in mixture.targets:
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
from_target(target),
|
446
|
-
)
|
453
|
+
entry = from_target(target)
|
454
|
+
if entry not in targets:
|
455
|
+
targets.append(entry)
|
456
|
+
for target in track(targets, disable=not show_progress):
|
457
|
+
con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
|
447
458
|
|
448
459
|
# Populate mixture table
|
449
|
-
|
460
|
+
if logging:
|
461
|
+
logger.info("Populating mixture table")
|
462
|
+
for mixture in track(mixtures, disable=not show_progress):
|
463
|
+
m_id = int(mixture.name)
|
450
464
|
con.execute(
|
451
465
|
"""
|
452
466
|
INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
|
@@ -473,20 +487,13 @@ def populate_mixture_table(
|
|
473
487
|
con.commit()
|
474
488
|
con.close()
|
475
489
|
|
476
|
-
return used_noise_files, used_noise_samples
|
477
490
|
|
478
|
-
|
479
|
-
def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData:
|
491
|
+
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
480
492
|
"""Update mixture record with name and gains"""
|
481
493
|
from .audio import get_next_noise
|
482
494
|
from .augmentation import apply_gain
|
483
|
-
from .datatypes import GenMixData
|
484
|
-
from .helpers import from_mixture
|
485
495
|
from .helpers import get_target
|
486
|
-
from .mixdb import db_connection
|
487
496
|
|
488
|
-
mixdb = MixtureDatabase(location, test)
|
489
|
-
mixture = mixdb.mixture(m_id)
|
490
497
|
mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
|
491
498
|
|
492
499
|
noise_audio = _augmented_noise_audio(mixdb, mixture)
|
@@ -501,29 +508,8 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
501
508
|
|
502
509
|
mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
|
503
510
|
|
504
|
-
con = db_connection(location=location, readonly=False, test=test)
|
505
|
-
con.execute(
|
506
|
-
"""
|
507
|
-
UPDATE mixture SET name=?,
|
508
|
-
noise_file_id=?,
|
509
|
-
noise_augmentation=?,
|
510
|
-
noise_offset=?,
|
511
|
-
noise_snr_gain=?,
|
512
|
-
random_snr=?,
|
513
|
-
snr=?,
|
514
|
-
samples=?,
|
515
|
-
spectral_mask_id=?,
|
516
|
-
spectral_mask_seed=?,
|
517
|
-
target_snr_gain=?
|
518
|
-
WHERE ? = mixture.id
|
519
|
-
""",
|
520
|
-
(*from_mixture(mixture), m_id + 1),
|
521
|
-
)
|
522
|
-
con.commit()
|
523
|
-
con.close()
|
524
|
-
|
525
511
|
if not with_data:
|
526
|
-
return GenMixData()
|
512
|
+
return mixture, GenMixData()
|
527
513
|
|
528
514
|
# Apply SNR gains
|
529
515
|
targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
|
@@ -533,7 +519,7 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
533
519
|
target_audio = get_target(mixdb, mixture, targets_audio)
|
534
520
|
mixture_audio = target_audio + noise_audio
|
535
521
|
|
536
|
-
return GenMixData(
|
522
|
+
return mixture, GenMixData(
|
537
523
|
mixture=mixture_audio,
|
538
524
|
targets=targets_audio,
|
539
525
|
target=target_audio,
|
@@ -553,7 +539,7 @@ def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
|
|
553
539
|
audio = read_audio(noise.name)
|
554
540
|
audio = apply_augmentation(audio, noise_augmentation)
|
555
541
|
if noise_augmentation.ir is not None:
|
556
|
-
audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir)))
|
542
|
+
audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir))) # pyright: ignore [reportArgumentType]
|
557
543
|
|
558
544
|
return audio
|
559
545
|
|
@@ -582,7 +568,10 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
582
568
|
|
583
569
|
|
584
570
|
def _initialize_mixture_gains(
|
585
|
-
mixdb: MixtureDatabase,
|
571
|
+
mixdb: MixtureDatabase,
|
572
|
+
mixture: Mixture,
|
573
|
+
target_audio: AudioT,
|
574
|
+
noise_audio: AudioT,
|
586
575
|
) -> Mixture:
|
587
576
|
import numpy as np
|
588
577
|
|
sonusai/mixture/helpers.py
CHANGED
@@ -135,14 +135,20 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
|
|
135
135
|
return results
|
136
136
|
|
137
137
|
|
138
|
-
def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
138
|
+
def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> str:
|
139
139
|
"""Create a string of metadata for a Mixture
|
140
140
|
|
141
141
|
:param mixdb: Mixture database
|
142
142
|
:param m_id: Mixture ID
|
143
|
+
:param mixture: Mixture record
|
143
144
|
:return: String of metadata
|
144
145
|
"""
|
145
|
-
|
146
|
+
if m_id is not None:
|
147
|
+
mixture = mixdb.mixture(m_id)
|
148
|
+
|
149
|
+
if mixture is None:
|
150
|
+
raise ValueError("No mixture specified.")
|
151
|
+
|
146
152
|
metadata = ""
|
147
153
|
speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
|
148
154
|
for mi, target in enumerate(mixture.targets):
|
@@ -173,17 +179,25 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
|
173
179
|
return metadata
|
174
180
|
|
175
181
|
|
176
|
-
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> None:
|
182
|
+
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> None:
|
177
183
|
"""Write mixture metadata to a text file
|
178
184
|
|
179
185
|
:param mixdb: Mixture database
|
180
186
|
:param m_id: Mixture ID
|
187
|
+
:param mixture: Mixture record
|
181
188
|
"""
|
182
189
|
from os.path import join
|
183
190
|
|
184
|
-
|
191
|
+
if m_id is not None:
|
192
|
+
name = mixdb.mixture(m_id).name
|
193
|
+
elif mixture is not None:
|
194
|
+
name = mixture.name
|
195
|
+
else:
|
196
|
+
raise ValueError("No mixture specified.")
|
197
|
+
|
198
|
+
name = join(mixdb.location, "mixture", name, "metadata.txt")
|
185
199
|
with open(file=name, mode="w") as f:
|
186
|
-
f.write(mixture_metadata(mixdb, m_id))
|
200
|
+
f.write(mixture_metadata(mixdb, m_id, mixture))
|
187
201
|
|
188
202
|
|
189
203
|
def from_mixture(
|
@@ -246,12 +260,13 @@ def to_target(entry: TargetRecord) -> Target:
|
|
246
260
|
)
|
247
261
|
|
248
262
|
|
249
|
-
def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT]) -> AudioT:
|
263
|
+
def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT], use_cache: bool = True) -> AudioT:
|
250
264
|
"""Get the augmented target audio data for the given mixture record
|
251
265
|
|
252
266
|
:param mixdb: Mixture database
|
253
267
|
:param mixture: Mixture record
|
254
268
|
:param targets_audio: List of augmented target audio data (one per target in the mixup)
|
269
|
+
:param use_cache: If true, use LRU caching
|
255
270
|
:return: Sum of augmented target audio data
|
256
271
|
"""
|
257
272
|
# Apply impulse responses to targets
|
@@ -265,7 +280,7 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[Aud
|
|
265
280
|
ir_idx = mixture.targets[idx].augmentation.ir
|
266
281
|
if ir_idx is not None:
|
267
282
|
targets_ir.append(
|
268
|
-
apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
|
283
|
+
apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx)), use_cache)) # pyright: ignore [reportArgumentType]
|
269
284
|
)
|
270
285
|
else:
|
271
286
|
targets_ir.append(target)
|
sonusai/mixture/mixdb.py
CHANGED
@@ -61,7 +61,7 @@ def db_connection(
|
|
61
61
|
if not create and readonly:
|
62
62
|
name += "?mode=ro"
|
63
63
|
|
64
|
-
connection = sqlite3.connect("file:" + name, uri=True)
|
64
|
+
connection = sqlite3.connect("file:" + name, uri=True, timeout=20)
|
65
65
|
|
66
66
|
if verbose:
|
67
67
|
connection.set_trace_callback(print)
|
@@ -84,7 +84,7 @@ class SQLiteContextManager:
|
|
84
84
|
|
85
85
|
|
86
86
|
class MixtureDatabase:
|
87
|
-
def __init__(self, location: str, test: bool = False) -> None:
|
87
|
+
def __init__(self, location: str, test: bool = False, use_cache: bool = True) -> None:
|
88
88
|
import json
|
89
89
|
from os.path import exists
|
90
90
|
|
@@ -92,6 +92,7 @@ class MixtureDatabase:
|
|
92
92
|
|
93
93
|
self.location = location
|
94
94
|
self.test = test
|
95
|
+
self.use_cache = use_cache
|
95
96
|
|
96
97
|
if not exists(db_file(self.location, self.test)):
|
97
98
|
raise OSError(f"Could not find mixture database in {self.location}")
|
@@ -121,7 +122,7 @@ class MixtureDatabase:
|
|
121
122
|
class_weights_threshold=self.class_weights_thresholds,
|
122
123
|
feature=self.feature,
|
123
124
|
impulse_response_files=self.impulse_response_files,
|
124
|
-
mixtures=self.mixtures,
|
125
|
+
mixtures=self.mixtures(),
|
125
126
|
noise_mix_mode=self.noise_mix_mode,
|
126
127
|
noise_files=self.noise_files,
|
127
128
|
num_classes=self.num_classes,
|
@@ -488,7 +489,7 @@ class MixtureDatabase:
|
|
488
489
|
return truth_configs
|
489
490
|
|
490
491
|
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
491
|
-
return _target_truth_configs(self.db, t_id)
|
492
|
+
return _target_truth_configs(self.db, t_id, self.use_cache)
|
492
493
|
|
493
494
|
@cached_property
|
494
495
|
def random_snrs(self) -> list[float]:
|
@@ -556,7 +557,7 @@ class MixtureDatabase:
|
|
556
557
|
:param sm_id: Spectral mask ID
|
557
558
|
:return: Spectral mask
|
558
559
|
"""
|
559
|
-
return _spectral_mask(self.db, sm_id)
|
560
|
+
return _spectral_mask(self.db, sm_id, self.use_cache)
|
560
561
|
|
561
562
|
@cached_property
|
562
563
|
def target_files(self) -> list[TargetFile]:
|
@@ -619,7 +620,7 @@ class MixtureDatabase:
|
|
619
620
|
:param t_id: Target file ID
|
620
621
|
:return: Target file
|
621
622
|
"""
|
622
|
-
return _target_file(self.db, t_id)
|
623
|
+
return _target_file(self.db, t_id, self.use_cache)
|
623
624
|
|
624
625
|
@cached_property
|
625
626
|
def num_target_files(self) -> int:
|
@@ -657,7 +658,7 @@ class MixtureDatabase:
|
|
657
658
|
:param n_id: Noise file ID
|
658
659
|
:return: Noise file
|
659
660
|
"""
|
660
|
-
return _noise_file(self.db, n_id)
|
661
|
+
return _noise_file(self.db, n_id, self.use_cache)
|
661
662
|
|
662
663
|
@cached_property
|
663
664
|
def num_noise_files(self) -> int:
|
@@ -706,7 +707,7 @@ class MixtureDatabase:
|
|
706
707
|
"""
|
707
708
|
if ir_id is None:
|
708
709
|
return None
|
709
|
-
return _impulse_response_file(self.db, ir_id)
|
710
|
+
return _impulse_response_file(self.db, ir_id, self.use_cache)
|
710
711
|
|
711
712
|
@cached_property
|
712
713
|
def num_impulse_response_files(self) -> int:
|
@@ -717,7 +718,6 @@ class MixtureDatabase:
|
|
717
718
|
with self.db() as c:
|
718
719
|
return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
|
719
720
|
|
720
|
-
@cached_property
|
721
721
|
def mixtures(self) -> list[Mixture]:
|
722
722
|
"""Get mixtures from db
|
723
723
|
|
@@ -760,7 +760,7 @@ class MixtureDatabase:
|
|
760
760
|
:param m_id: Zero-based mixture ID
|
761
761
|
:return: Mixture record
|
762
762
|
"""
|
763
|
-
return _mixture(self.db, m_id)
|
763
|
+
return _mixture(self.db, m_id, self.use_cache)
|
764
764
|
|
765
765
|
@cached_property
|
766
766
|
def mixid_width(self) -> int:
|
@@ -805,7 +805,7 @@ class MixtureDatabase:
|
|
805
805
|
"""
|
806
806
|
from .audio import read_audio
|
807
807
|
|
808
|
-
return read_audio(self.target_file(t_id).name)
|
808
|
+
return read_audio(self.target_file(t_id).name, self.use_cache)
|
809
809
|
|
810
810
|
def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
|
811
811
|
"""Get augmented noise audio
|
@@ -819,12 +819,12 @@ class MixtureDatabase:
|
|
819
819
|
from .augmentation import apply_impulse_response
|
820
820
|
|
821
821
|
noise = self.noise_file(mixture.noise.file_id)
|
822
|
-
audio = read_audio(noise.name)
|
822
|
+
audio = read_audio(noise.name, self.use_cache)
|
823
823
|
audio = apply_augmentation(audio, mixture.noise.augmentation)
|
824
824
|
if mixture.noise.augmentation.ir is not None:
|
825
825
|
audio = apply_impulse_response(
|
826
826
|
audio,
|
827
|
-
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
|
827
|
+
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir), self.use_cache), # pyright: ignore [reportArgumentType]
|
828
828
|
)
|
829
829
|
|
830
830
|
return audio
|
@@ -1332,7 +1332,7 @@ class MixtureDatabase:
|
|
1332
1332
|
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1333
1333
|
|
1334
1334
|
def speaker(self, s_id: int | None, tier: str) -> str | None:
|
1335
|
-
return _speaker(self.db, s_id, tier)
|
1335
|
+
return _speaker(self.db, s_id, tier, self.use_cache)
|
1336
1336
|
|
1337
1337
|
def speech_metadata(self, tier: str) -> list[str]:
|
1338
1338
|
from .helpers import get_textgrid_tier_from_target_file
|
@@ -1464,7 +1464,7 @@ class MixtureDatabase:
|
|
1464
1464
|
|
1465
1465
|
return sorted(result)
|
1466
1466
|
|
1467
|
-
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) ->
|
1467
|
+
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
|
1468
1468
|
"""Get metrics data for the given mixture ID
|
1469
1469
|
|
1470
1470
|
:param m_id: Zero-based mixture ID
|
@@ -1916,21 +1916,34 @@ class MixtureDatabase:
|
|
1916
1916
|
|
1917
1917
|
raise AttributeError(f"Unrecognized metric: '{m}'")
|
1918
1918
|
|
1919
|
-
result:
|
1919
|
+
result: dict[str, Any] = {}
|
1920
1920
|
for metric in metrics:
|
1921
|
-
result
|
1921
|
+
result[metric] = calc(metric)
|
1922
|
+
|
1923
|
+
# Check for metrics dependencies and add them even if not explicitly requested.
|
1924
|
+
if metric.startswith("mxwer"):
|
1925
|
+
dependencies = ("mxasr." + metric[6:], "tasr." + metric[6:])
|
1926
|
+
for dependency in dependencies:
|
1927
|
+
result[dependency] = calc(dependency)
|
1922
1928
|
|
1923
1929
|
return result
|
1924
1930
|
|
1925
1931
|
|
1926
|
-
|
1927
|
-
def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1932
|
+
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
|
1928
1933
|
"""Get spectral mask with ID from db
|
1929
1934
|
|
1930
1935
|
:param db: Database context
|
1931
1936
|
:param sm_id: Spectral mask ID
|
1937
|
+
:param use_cache: If true, use LRU caching
|
1932
1938
|
:return: Spectral mask
|
1933
1939
|
"""
|
1940
|
+
if use_cache:
|
1941
|
+
return __spectral_mask(db, sm_id)
|
1942
|
+
return __spectral_mask.__wrapped__(db, sm_id)
|
1943
|
+
|
1944
|
+
|
1945
|
+
@lru_cache
|
1946
|
+
def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1934
1947
|
from .db_datatypes import SpectralMaskRecord
|
1935
1948
|
|
1936
1949
|
with db() as c:
|
@@ -1953,12 +1966,26 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1953
1966
|
)
|
1954
1967
|
|
1955
1968
|
|
1969
|
+
def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
1970
|
+
"""Get target file with ID from db
|
1971
|
+
|
1972
|
+
:param db: Database context
|
1973
|
+
:param t_id: Target file ID
|
1974
|
+
:param use_cache: If true, use LRU caching
|
1975
|
+
:return: Target file
|
1976
|
+
"""
|
1977
|
+
if use_cache:
|
1978
|
+
return __target_file(db, t_id, use_cache)
|
1979
|
+
return __target_file.__wrapped__(db, t_id, use_cache)
|
1980
|
+
|
1981
|
+
|
1956
1982
|
@lru_cache
|
1957
|
-
def _target_file(db: partial, t_id: int) -> TargetFile:
|
1983
|
+
def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
1958
1984
|
"""Get target file with ID from db
|
1959
1985
|
|
1960
1986
|
:param db: Database context
|
1961
1987
|
:param t_id: Target file ID
|
1988
|
+
:param use_cache: If true, use LRU caching
|
1962
1989
|
:return: Target file
|
1963
1990
|
"""
|
1964
1991
|
import json
|
@@ -1982,19 +2009,26 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1982
2009
|
samples=target_file.samples,
|
1983
2010
|
class_indices=json.loads(target_file.class_indices),
|
1984
2011
|
level_type=target_file.level_type,
|
1985
|
-
truth_configs=_target_truth_configs(db, t_id),
|
2012
|
+
truth_configs=_target_truth_configs(db, t_id, use_cache),
|
1986
2013
|
speaker_id=target_file.speaker_id,
|
1987
2014
|
)
|
1988
2015
|
|
1989
2016
|
|
1990
|
-
|
1991
|
-
def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
2017
|
+
def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
|
1992
2018
|
"""Get noise file with ID from db
|
1993
2019
|
|
1994
2020
|
:param db: Database context
|
1995
2021
|
:param n_id: Noise file ID
|
2022
|
+
:param use_cache: If true, use LRU caching
|
1996
2023
|
:return: Noise file
|
1997
2024
|
"""
|
2025
|
+
if use_cache:
|
2026
|
+
return __noise_file(db, n_id)
|
2027
|
+
return __noise_file.__wrapped__(db, n_id)
|
2028
|
+
|
2029
|
+
|
2030
|
+
@lru_cache
|
2031
|
+
def __noise_file(db: partial, n_id: int) -> NoiseFile:
|
1998
2032
|
with db() as c:
|
1999
2033
|
noise = c.execute(
|
2000
2034
|
"""
|
@@ -2007,14 +2041,21 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
2007
2041
|
return NoiseFile(name=noise[0], samples=noise[1])
|
2008
2042
|
|
2009
2043
|
|
2010
|
-
|
2011
|
-
def _impulse_response_file(db: partial, ir_id: int) -> str:
|
2044
|
+
def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
|
2012
2045
|
"""Get impulse response file with ID from db
|
2013
2046
|
|
2014
2047
|
:param db: Database context
|
2015
2048
|
:param ir_id: Impulse response file ID
|
2016
|
-
:
|
2049
|
+
:param use_cache: If true, use LRU caching
|
2050
|
+
:return: Impulse response
|
2017
2051
|
"""
|
2052
|
+
if use_cache:
|
2053
|
+
return __impulse_response_file(db, ir_id)
|
2054
|
+
return __impulse_response_file.__wrapped__(db, ir_id)
|
2055
|
+
|
2056
|
+
|
2057
|
+
@lru_cache
|
2058
|
+
def __impulse_response_file(db: partial, ir_id: int) -> str:
|
2018
2059
|
with db() as c:
|
2019
2060
|
return str(
|
2020
2061
|
c.execute(
|
@@ -2028,14 +2069,21 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
2028
2069
|
)
|
2029
2070
|
|
2030
2071
|
|
2031
|
-
|
2032
|
-
def _mixture(db: partial, m_id: int) -> Mixture:
|
2072
|
+
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
2033
2073
|
"""Get mixture record with ID from db
|
2034
2074
|
|
2035
2075
|
:param db: Database context
|
2036
2076
|
:param m_id: Zero-based mixture ID
|
2077
|
+
:param use_cache: If true, use LRU caching
|
2037
2078
|
:return: Mixture record
|
2038
2079
|
"""
|
2080
|
+
if use_cache:
|
2081
|
+
return __mixture(db, m_id)
|
2082
|
+
return __mixture.__wrapped__(db, m_id)
|
2083
|
+
|
2084
|
+
|
2085
|
+
@lru_cache
|
2086
|
+
def __mixture(db: partial, m_id: int) -> Mixture:
|
2039
2087
|
from .db_datatypes import MixtureRecord
|
2040
2088
|
from .db_datatypes import TargetRecord
|
2041
2089
|
from .helpers import to_mixture
|
@@ -2068,8 +2116,14 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
2068
2116
|
return to_mixture(mixture, targets)
|
2069
2117
|
|
2070
2118
|
|
2119
|
+
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
|
2120
|
+
if use_cache:
|
2121
|
+
return __speaker(db, s_id, tier)
|
2122
|
+
return __speaker.__wrapped__(db, s_id, tier)
|
2123
|
+
|
2124
|
+
|
2071
2125
|
@lru_cache
|
2072
|
-
def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
2126
|
+
def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
2073
2127
|
if s_id is None:
|
2074
2128
|
return None
|
2075
2129
|
|
@@ -2082,8 +2136,14 @@ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
|
2082
2136
|
return data[0]
|
2083
2137
|
|
2084
2138
|
|
2139
|
+
def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
|
2140
|
+
if use_cache:
|
2141
|
+
return __target_truth_configs(db, t_id)
|
2142
|
+
return __target_truth_configs.__wrapped__(db, t_id)
|
2143
|
+
|
2144
|
+
|
2085
2145
|
@lru_cache
|
2086
|
-
def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
2146
|
+
def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
2087
2147
|
import json
|
2088
2148
|
|
2089
2149
|
from .datatypes import TruthConfig
|