sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +56 -64
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +161 -204
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -8,27 +8,21 @@ from typing import Any
|
|
8
8
|
|
9
9
|
from .datatypes import ASRConfigs
|
10
10
|
from .datatypes import AudioF
|
11
|
-
from .datatypes import AudiosF
|
12
|
-
from .datatypes import AudiosT
|
13
11
|
from .datatypes import AudioT
|
14
12
|
from .datatypes import ClassCount
|
15
13
|
from .datatypes import Feature
|
16
14
|
from .datatypes import FeatureGeneratorConfig
|
17
15
|
from .datatypes import FeatureGeneratorInfo
|
18
16
|
from .datatypes import GeneralizedIDs
|
19
|
-
from .datatypes import
|
17
|
+
from .datatypes import ImpulseResponseFile
|
20
18
|
from .datatypes import MetricDoc
|
21
19
|
from .datatypes import MetricDocs
|
22
20
|
from .datatypes import Mixture
|
23
|
-
from .datatypes import Mixtures
|
24
21
|
from .datatypes import NoiseFile
|
25
|
-
from .datatypes import NoiseFiles
|
26
22
|
from .datatypes import Segsnr
|
27
23
|
from .datatypes import SpectralMask
|
28
|
-
from .datatypes import SpectralMasks
|
29
24
|
from .datatypes import SpeechMetadata
|
30
25
|
from .datatypes import TargetFile
|
31
|
-
from .datatypes import TargetFiles
|
32
26
|
from .datatypes import TransformConfig
|
33
27
|
from .datatypes import TruthConfigs
|
34
28
|
from .datatypes import TruthDict
|
@@ -46,7 +40,13 @@ def db_file(location: str, test: bool = False) -> str:
|
|
46
40
|
return join(location, name)
|
47
41
|
|
48
42
|
|
49
|
-
def db_connection(
|
43
|
+
def db_connection(
|
44
|
+
location: str,
|
45
|
+
create: bool = False,
|
46
|
+
readonly: bool = True,
|
47
|
+
test: bool = False,
|
48
|
+
verbose: bool = False,
|
49
|
+
) -> Connection:
|
50
50
|
import sqlite3
|
51
51
|
from os import remove
|
52
52
|
from os.path import exists
|
@@ -62,7 +62,10 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
|
|
62
62
|
name += "?mode=ro"
|
63
63
|
|
64
64
|
connection = sqlite3.connect("file:" + name, uri=True)
|
65
|
-
|
65
|
+
|
66
|
+
if verbose:
|
67
|
+
connection.set_trace_callback(print)
|
68
|
+
|
66
69
|
return connection
|
67
70
|
|
68
71
|
|
@@ -82,8 +85,30 @@ class SQLiteContextManager:
|
|
82
85
|
|
83
86
|
class MixtureDatabase:
|
84
87
|
def __init__(self, location: str, test: bool = False) -> None:
    """Open an existing mixture database and sync its stored ASR configs.

    The database file is located with db_file(location, test). After the
    read-only accessor is set up, asr_configs from the config loaded for
    this location is compared with the value stored in the 'top' table; if
    they differ, the stored value is overwritten.

    :param location: Directory containing the mixture database
    :param test: Forwarded to db_file() and db_connection()
    :raises OSError: If no mixture database file is found in location
    """
    import json
    from contextlib import closing
    from os.path import exists

    from .config import load_config

    self.location = location
    self.test = test

    if not exists(db_file(self.location, self.test)):
        raise OSError(f"Could not find mixture database in {self.location}")

    self.db = partial(SQLiteContextManager, self.location, self.test)

    # Check config.yml to see if asr_configs has changed and update database if needed
    config = load_config(self.location)
    new_asr_configs = json.dumps(config["asr_configs"])
    with self.db() as c:
        # Fetch the row id together with the stored value so the UPDATE below
        # targets exactly the row that was compared.
        top_row = c.execute("SELECT top.id, top.asr_configs FROM top").fetchone()

    if top_row is not None and new_asr_configs != top_row[1]:
        top_id = top_row[0]
        # closing() guarantees the writable connection is released even if
        # execute() or commit() raises.
        with closing(db_connection(location=self.location, readonly=False, test=self.test)) as con:
            # The statement has two placeholders, so two values must be bound;
            # binding only the new configs raises sqlite3.ProgrammingError.
            con.execute("UPDATE top SET asr_configs = ? WHERE ? = id", (new_asr_configs, top_id))
            con.commit()
|
87
112
|
|
88
113
|
@cached_property
|
89
114
|
def json(self) -> str:
|
@@ -127,10 +152,10 @@ class MixtureDatabase:
|
|
127
152
|
return get_feature_generator_info(self.fg_config)
|
128
153
|
|
129
154
|
@cached_property
|
130
|
-
def truth_parameters(self) -> dict[str, int]:
|
155
|
+
def truth_parameters(self) -> dict[str, int | None]:
|
131
156
|
with self.db() as c:
|
132
157
|
rows = c.execute("SELECT * FROM truth_parameters").fetchall()
|
133
|
-
truth_parameters: dict[str, int] = {}
|
158
|
+
truth_parameters: dict[str, int | None] = {}
|
134
159
|
for row in rows:
|
135
160
|
truth_parameters[row[1]] = row[2]
|
136
161
|
return truth_parameters
|
@@ -197,48 +222,58 @@ class MixtureDatabase:
|
|
197
222
|
"mxssnrdbf_std",
|
198
223
|
"Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
|
199
224
|
),
|
200
|
-
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true
|
225
|
+
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true targets"),
|
201
226
|
MetricDoc(
|
202
227
|
"Mixture Metrics",
|
203
228
|
"mxwsdr",
|
204
|
-
"Weighted signal
|
229
|
+
"Weighted signal distortion ratio of mixture versus true targets",
|
205
230
|
),
|
206
231
|
MetricDoc(
|
207
232
|
"Mixture Metrics",
|
208
233
|
"mxpd",
|
209
|
-
"Phase distance between mixture and true
|
234
|
+
"Phase distance between mixture and true targets",
|
210
235
|
),
|
211
236
|
MetricDoc(
|
212
237
|
"Mixture Metrics",
|
213
238
|
"mxstoi",
|
214
|
-
"Short term objective intelligibility of mixture versus true
|
239
|
+
"Short term objective intelligibility of mixture versus true targets",
|
215
240
|
),
|
216
241
|
MetricDoc(
|
217
242
|
"Mixture Metrics",
|
218
243
|
"mxcsig",
|
219
|
-
"Predicted rating of speech distortion of mixture versus true
|
244
|
+
"Predicted rating of speech distortion of mixture versus true targets",
|
220
245
|
),
|
221
246
|
MetricDoc(
|
222
247
|
"Mixture Metrics",
|
223
248
|
"mxcbak",
|
224
|
-
"Predicted rating of background distortion of mixture versus true
|
249
|
+
"Predicted rating of background distortion of mixture versus true targets",
|
225
250
|
),
|
226
251
|
MetricDoc(
|
227
252
|
"Mixture Metrics",
|
228
253
|
"mxcovl",
|
229
|
-
"Predicted rating of overall quality of mixture versus true
|
254
|
+
"Predicted rating of overall quality of mixture versus true targets",
|
230
255
|
),
|
231
256
|
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
232
|
-
MetricDoc("
|
233
|
-
MetricDoc("
|
234
|
-
MetricDoc("
|
235
|
-
MetricDoc("
|
236
|
-
MetricDoc("
|
237
|
-
MetricDoc("
|
238
|
-
MetricDoc("
|
239
|
-
MetricDoc("
|
240
|
-
MetricDoc("
|
241
|
-
MetricDoc("
|
257
|
+
MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
|
258
|
+
MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
|
259
|
+
MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
|
260
|
+
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture target Pk lev dB"),
|
261
|
+
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture target RMS lev dB"),
|
262
|
+
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture target RMS Pk dB"),
|
263
|
+
MetricDoc("Mixture Metrics", "mxttr", "Mixture target RMS Tr dB"),
|
264
|
+
MetricDoc("Mixture Metrics", "mxtcr", "Mixture target Crest factor"),
|
265
|
+
MetricDoc("Mixture Metrics", "mxtfl", "Mixture target Flat factor"),
|
266
|
+
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture target Pk count"),
|
267
|
+
MetricDoc("Targets Metrics", "tdco", "Targets DC offset"),
|
268
|
+
MetricDoc("Targets Metrics", "tmin", "Targets min level"),
|
269
|
+
MetricDoc("Targets Metrics", "tmax", "Targets max levl"),
|
270
|
+
MetricDoc("Targets Metrics", "tpkdb", "Targets Pk lev dB"),
|
271
|
+
MetricDoc("Targets Metrics", "tlrms", "Targets RMS lev dB"),
|
272
|
+
MetricDoc("Targets Metrics", "tpkr", "Targets RMS Pk dB"),
|
273
|
+
MetricDoc("Targets Metrics", "ttr", "Targets RMS Tr dB"),
|
274
|
+
MetricDoc("Targets Metrics", "tcr", "Targets Crest factor"),
|
275
|
+
MetricDoc("Targets Metrics", "tfl", "Targets Flat factor"),
|
276
|
+
MetricDoc("Targets Metrics", "tpkc", "Targets Pk count"),
|
242
277
|
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
243
278
|
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
244
279
|
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
@@ -272,11 +307,18 @@ class MixtureDatabase:
|
|
272
307
|
]
|
273
308
|
)
|
274
309
|
for name in self.asr_configs:
|
310
|
+
metrics.append(
|
311
|
+
MetricDoc(
|
312
|
+
"Target Metrics",
|
313
|
+
f"mxtasr.{name}",
|
314
|
+
f"Mixture Target ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
315
|
+
)
|
316
|
+
)
|
275
317
|
metrics.append(
|
276
318
|
MetricDoc(
|
277
319
|
"Target Metrics",
|
278
320
|
f"tasr.{name}",
|
279
|
-
f"
|
321
|
+
f"Targets ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
280
322
|
)
|
281
323
|
)
|
282
324
|
metrics.append(
|
@@ -486,7 +528,7 @@ class MixtureDatabase:
|
|
486
528
|
)
|
487
529
|
|
488
530
|
@cached_property
|
489
|
-
def spectral_masks(self) ->
|
531
|
+
def spectral_masks(self) -> list[SpectralMask]:
|
490
532
|
"""Get spectral masks from db
|
491
533
|
|
492
534
|
:return: Spectral masks
|
@@ -517,7 +559,7 @@ class MixtureDatabase:
|
|
517
559
|
return _spectral_mask(self.db, sm_id)
|
518
560
|
|
519
561
|
@cached_property
|
520
|
-
def target_files(self) ->
|
562
|
+
def target_files(self) -> list[TargetFile]:
|
521
563
|
"""Get target files from db
|
522
564
|
|
523
565
|
:return: Target files
|
@@ -529,17 +571,19 @@ class MixtureDatabase:
|
|
529
571
|
from .db_datatypes import TargetFileRecord
|
530
572
|
|
531
573
|
with self.db() as c:
|
532
|
-
target_files:
|
574
|
+
target_files: list[TargetFile] = []
|
533
575
|
target_file_records = [
|
534
576
|
TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
|
535
577
|
]
|
536
578
|
for target_file_record in target_file_records:
|
537
579
|
truth_configs: TruthConfigs = {}
|
538
580
|
for truth_config_records in c.execute(
|
539
|
-
"
|
540
|
-
|
541
|
-
|
542
|
-
|
581
|
+
"""
|
582
|
+
SELECT truth_config.config
|
583
|
+
FROM truth_config, target_file_truth_config
|
584
|
+
WHERE ? = target_file_truth_config.target_file_id
|
585
|
+
AND truth_config.id = target_file_truth_config.truth_config_id
|
586
|
+
""",
|
543
587
|
(target_file_record.id,),
|
544
588
|
).fetchall():
|
545
589
|
truth_config = json.loads(truth_config_records[0])
|
@@ -587,7 +631,7 @@ class MixtureDatabase:
|
|
587
631
|
return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
|
588
632
|
|
589
633
|
@cached_property
|
590
|
-
def noise_files(self) ->
|
634
|
+
def noise_files(self) -> list[NoiseFile]:
|
591
635
|
"""Get noise files from db
|
592
636
|
|
593
637
|
:return: Noise files
|
@@ -625,7 +669,7 @@ class MixtureDatabase:
|
|
625
669
|
return int(c.execute("SELECT count(noise_file.id) FROM noise_file").fetchone()[0])
|
626
670
|
|
627
671
|
@cached_property
|
628
|
-
def impulse_response_files(self) ->
|
672
|
+
def impulse_response_files(self) -> list[ImpulseResponseFile]:
|
629
673
|
"""Get impulse response files from db
|
630
674
|
|
631
675
|
:return: Impulse response files
|
@@ -635,10 +679,6 @@ class MixtureDatabase:
|
|
635
679
|
from .datatypes import ImpulseResponseFile
|
636
680
|
|
637
681
|
with self.db() as c:
|
638
|
-
# for impulse_response in c.execute(
|
639
|
-
# "SELECT impulse_response_file.* FROM impulse_response_file"
|
640
|
-
# ).fetchall():
|
641
|
-
# print(impulse_response)
|
642
682
|
return [
|
643
683
|
ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
|
644
684
|
for impulse_response in c.execute(
|
@@ -678,7 +718,7 @@ class MixtureDatabase:
|
|
678
718
|
return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
|
679
719
|
|
680
720
|
@cached_property
|
681
|
-
def mixtures(self) ->
|
721
|
+
def mixtures(self) -> list[Mixture]:
|
682
722
|
"""Get mixtures from db
|
683
723
|
|
684
724
|
:return: Mixtures
|
@@ -689,13 +729,16 @@ class MixtureDatabase:
|
|
689
729
|
from .helpers import to_target
|
690
730
|
|
691
731
|
with self.db() as c:
|
692
|
-
mixtures:
|
732
|
+
mixtures: list[Mixture] = []
|
693
733
|
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
694
734
|
targets = [
|
695
735
|
to_target(TargetRecord(*target))
|
696
736
|
for target in c.execute(
|
697
|
-
"
|
698
|
-
|
737
|
+
"""
|
738
|
+
SELECT target.*
|
739
|
+
FROM target, mixture_target
|
740
|
+
WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
|
741
|
+
""",
|
699
742
|
(mixture.id,),
|
700
743
|
).fetchall()
|
701
744
|
]
|
@@ -744,7 +787,7 @@ class MixtureDatabase:
|
|
744
787
|
return int(c.execute("SELECT count(mixture.id) FROM mixture").fetchone()[0])
|
745
788
|
|
746
789
|
def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
|
747
|
-
"""Read mixture data
|
790
|
+
"""Read mixture data
|
748
791
|
|
749
792
|
:param m_id: Zero-based mixture ID
|
750
793
|
:param items: String(s) of dataset(s) to retrieve
|
@@ -792,7 +835,7 @@ class MixtureDatabase:
|
|
792
835
|
class_indices.extend(self.target_file(t_id).class_indices)
|
793
836
|
return sorted(set(class_indices))
|
794
837
|
|
795
|
-
def mixture_targets(self, m_id: int, force: bool = False) ->
|
838
|
+
def mixture_targets(self, m_id: int, force: bool = False) -> list[AudioT]:
|
796
839
|
"""Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
|
797
840
|
|
798
841
|
:param m_id: Zero-based mixture ID
|
@@ -826,7 +869,7 @@ class MixtureDatabase:
|
|
826
869
|
|
827
870
|
return targets_audio
|
828
871
|
|
829
|
-
def mixture_targets_f(self, m_id: int, targets:
|
872
|
+
def mixture_targets_f(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> list[AudioF]:
|
830
873
|
"""Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
|
831
874
|
|
832
875
|
:param m_id: Zero-based mixture ID
|
@@ -841,7 +884,7 @@ class MixtureDatabase:
|
|
841
884
|
|
842
885
|
return [forward_transform(target, self.ft_config) for target in targets]
|
843
886
|
|
844
|
-
def mixture_target(self, m_id: int, targets:
|
887
|
+
def mixture_target(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> AudioT:
|
845
888
|
"""Get the augmented target audio data for the given mixture ID
|
846
889
|
|
847
890
|
:param m_id: Zero-based mixture ID
|
@@ -864,7 +907,7 @@ class MixtureDatabase:
|
|
864
907
|
def mixture_target_f(
|
865
908
|
self,
|
866
909
|
m_id: int,
|
867
|
-
targets:
|
910
|
+
targets: list[AudioT] | None = None,
|
868
911
|
target: AudioT | None = None,
|
869
912
|
force: bool = False,
|
870
913
|
) -> AudioF:
|
@@ -900,7 +943,7 @@ class MixtureDatabase:
|
|
900
943
|
|
901
944
|
mixture = self.mixture(m_id)
|
902
945
|
noise = self.augmented_noise_audio(mixture)
|
903
|
-
noise = get_next_noise(audio=noise, offset=mixture.
|
946
|
+
noise = get_next_noise(audio=noise, offset=mixture.noise_offset, length=mixture.samples)
|
904
947
|
return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
|
905
948
|
|
906
949
|
def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
|
@@ -921,7 +964,7 @@ class MixtureDatabase:
|
|
921
964
|
def mixture_mixture(
|
922
965
|
self,
|
923
966
|
m_id: int,
|
924
|
-
targets:
|
967
|
+
targets: list[AudioT] | None = None,
|
925
968
|
target: AudioT | None = None,
|
926
969
|
noise: AudioT | None = None,
|
927
970
|
force: bool = False,
|
@@ -951,7 +994,7 @@ class MixtureDatabase:
|
|
951
994
|
def mixture_mixture_f(
|
952
995
|
self,
|
953
996
|
m_id: int,
|
954
|
-
targets:
|
997
|
+
targets: list[AudioT] | None = None,
|
955
998
|
target: AudioT | None = None,
|
956
999
|
noise: AudioT | None = None,
|
957
1000
|
mixture: AudioT | None = None,
|
@@ -988,11 +1031,11 @@ class MixtureDatabase:
|
|
988
1031
|
def mixture_truth_t(
|
989
1032
|
self,
|
990
1033
|
m_id: int,
|
991
|
-
targets:
|
1034
|
+
targets: list[AudioT] | None = None,
|
992
1035
|
noise: AudioT | None = None,
|
993
1036
|
mixture: AudioT | None = None,
|
994
1037
|
force: bool = False,
|
995
|
-
) -> TruthDict:
|
1038
|
+
) -> list[TruthDict]:
|
996
1039
|
"""Get the truth_t data for the given mixture ID
|
997
1040
|
|
998
1041
|
:param m_id: Zero-based mixture ID
|
@@ -1000,9 +1043,9 @@ class MixtureDatabase:
|
|
1000
1043
|
:param noise: Augmented noise audio data for the given mixture ID
|
1001
1044
|
:param mixture: Mixture audio data for the given mixture ID
|
1002
1045
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1003
|
-
:return: truth_t data
|
1046
|
+
:return: list of truth_t data
|
1004
1047
|
"""
|
1005
|
-
from .
|
1048
|
+
from .truth import truth_function
|
1006
1049
|
|
1007
1050
|
if not force:
|
1008
1051
|
truth_t = self.read_mixture_data(m_id, "truth_t")
|
@@ -1018,12 +1061,18 @@ class MixtureDatabase:
|
|
1018
1061
|
if force or mixture is None:
|
1019
1062
|
mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
|
1020
1063
|
|
1021
|
-
|
1064
|
+
if not all(len(target) == self.mixture(m_id).samples for target in targets):
|
1065
|
+
raise ValueError("Lengths of targets do not match length of mixture")
|
1066
|
+
|
1067
|
+
if len(noise) != self.mixture(m_id).samples:
|
1068
|
+
raise ValueError("Length of noise does not match length of mixture")
|
1069
|
+
|
1070
|
+
return truth_function(self, m_id)
|
1022
1071
|
|
1023
1072
|
def mixture_segsnr_t(
|
1024
1073
|
self,
|
1025
1074
|
m_id: int,
|
1026
|
-
targets:
|
1075
|
+
targets: list[AudioT] | None = None,
|
1027
1076
|
target: AudioT | None = None,
|
1028
1077
|
noise: AudioT | None = None,
|
1029
1078
|
force: bool = False,
|
@@ -1037,7 +1086,9 @@ class MixtureDatabase:
|
|
1037
1086
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1038
1087
|
:return: segsnr_t data
|
1039
1088
|
"""
|
1040
|
-
|
1089
|
+
import numpy as np
|
1090
|
+
import torch
|
1091
|
+
from pyaaware import ForwardTransform
|
1041
1092
|
|
1042
1093
|
if not force:
|
1043
1094
|
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
@@ -1050,13 +1101,45 @@ class MixtureDatabase:
|
|
1050
1101
|
if force or noise is None:
|
1051
1102
|
noise = self.mixture_noise(m_id, force)
|
1052
1103
|
|
1053
|
-
|
1104
|
+
ft = ForwardTransform(
|
1105
|
+
length=self.ft_config.length,
|
1106
|
+
overlap=self.ft_config.overlap,
|
1107
|
+
bin_start=self.ft_config.bin_start,
|
1108
|
+
bin_end=self.ft_config.bin_end,
|
1109
|
+
ttype=self.ft_config.ttype,
|
1110
|
+
)
|
1111
|
+
|
1112
|
+
mixture = self.mixture(m_id)
|
1113
|
+
|
1114
|
+
segsnr_t = np.empty(mixture.samples, dtype=np.float32)
|
1115
|
+
|
1116
|
+
target_energy = ft.execute_all(torch.from_numpy(target))[1].numpy()
|
1117
|
+
noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
|
1118
|
+
|
1119
|
+
offsets = range(0, mixture.samples, self.ft_config.overlap)
|
1120
|
+
if len(target_energy) != len(offsets):
|
1121
|
+
raise ValueError(
|
1122
|
+
f"Number of frames in energy, {len(target_energy)},"
|
1123
|
+
f" is not number of frames in mixture, {len(offsets)}"
|
1124
|
+
)
|
1125
|
+
|
1126
|
+
for idx, offset in enumerate(offsets):
|
1127
|
+
indices = slice(offset, offset + self.ft_config.overlap)
|
1128
|
+
|
1129
|
+
if noise_energy[idx] == 0:
|
1130
|
+
snr = np.float32(np.inf)
|
1131
|
+
else:
|
1132
|
+
snr = np.float32(target_energy[idx] / noise_energy[idx])
|
1133
|
+
|
1134
|
+
segsnr_t[indices] = snr
|
1135
|
+
|
1136
|
+
return segsnr_t
|
1054
1137
|
|
1055
1138
|
def mixture_segsnr(
|
1056
1139
|
self,
|
1057
1140
|
m_id: int,
|
1058
1141
|
segsnr_t: Segsnr | None = None,
|
1059
|
-
targets:
|
1142
|
+
targets: list[AudioT] | None = None,
|
1060
1143
|
target: AudioT | None = None,
|
1061
1144
|
noise: AudioT | None = None,
|
1062
1145
|
force: bool = False,
|
@@ -1088,12 +1171,12 @@ class MixtureDatabase:
|
|
1088
1171
|
def mixture_ft(
|
1089
1172
|
self,
|
1090
1173
|
m_id: int,
|
1091
|
-
targets:
|
1174
|
+
targets: list[AudioT] | None = None,
|
1092
1175
|
target: AudioT | None = None,
|
1093
1176
|
noise: AudioT | None = None,
|
1094
1177
|
mixture_f: AudioF | None = None,
|
1095
1178
|
mixture: AudioT | None = None,
|
1096
|
-
truth_t: TruthDict | None = None,
|
1179
|
+
truth_t: list[TruthDict] | None = None,
|
1097
1180
|
force: bool = False,
|
1098
1181
|
) -> tuple[Feature, TruthDict]:
|
1099
1182
|
"""Get the feature and truth_f data for the given mixture ID
|
@@ -1132,19 +1215,24 @@ class MixtureDatabase:
|
|
1132
1215
|
|
1133
1216
|
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
1134
1217
|
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1218
|
+
# TODO: handle mixup in truth_t
|
1219
|
+
feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
|
1220
|
+
if truth_f is not None:
|
1221
|
+
for key in self.truth_configs:
|
1222
|
+
if self.truth_parameters[key] is not None:
|
1223
|
+
truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
|
1224
|
+
else:
|
1225
|
+
raise TypeError("Unexpected truth of None from feature generator")
|
1138
1226
|
|
1139
1227
|
return feature, truth_f
|
1140
1228
|
|
1141
1229
|
def mixture_feature(
|
1142
1230
|
self,
|
1143
1231
|
m_id: int,
|
1144
|
-
targets:
|
1232
|
+
targets: list[AudioT] | None = None,
|
1145
1233
|
noise: AudioT | None = None,
|
1146
1234
|
mixture: AudioT | None = None,
|
1147
|
-
truth_t: TruthDict | None = None,
|
1235
|
+
truth_t: list[TruthDict] | None = None,
|
1148
1236
|
force: bool = False,
|
1149
1237
|
) -> Feature:
|
1150
1238
|
"""Get the feature data for the given mixture ID
|
@@ -1170,10 +1258,10 @@ class MixtureDatabase:
|
|
1170
1258
|
def mixture_truth_f(
|
1171
1259
|
self,
|
1172
1260
|
m_id: int,
|
1173
|
-
targets:
|
1261
|
+
targets: list[AudioT] | None = None,
|
1174
1262
|
noise: AudioT | None = None,
|
1175
1263
|
mixture: AudioT | None = None,
|
1176
|
-
truth_t: TruthDict | None = None,
|
1264
|
+
truth_t: list[TruthDict] | None = None,
|
1177
1265
|
force: bool = False,
|
1178
1266
|
) -> TruthDict:
|
1179
1267
|
"""Get the truth_f data for the given mixture ID
|
@@ -1199,9 +1287,9 @@ class MixtureDatabase:
|
|
1199
1287
|
def mixture_class_count(
|
1200
1288
|
self,
|
1201
1289
|
m_id: int,
|
1202
|
-
targets:
|
1290
|
+
targets: list[AudioT] | None = None,
|
1203
1291
|
noise: AudioT | None = None,
|
1204
|
-
truth_t: TruthDict | None = None,
|
1292
|
+
truth_t: list[TruthDict] | None = None,
|
1205
1293
|
) -> ClassCount:
|
1206
1294
|
"""Compute the number of frames for which each class index is active for the given mixture ID
|
1207
1295
|
|
@@ -1220,7 +1308,8 @@ class MixtureDatabase:
|
|
1220
1308
|
num_classes = self.num_classes
|
1221
1309
|
if "sed" in self.truth_configs:
|
1222
1310
|
for cl in range(num_classes):
|
1223
|
-
|
1311
|
+
# TODO: handle mixup in truth_t
|
1312
|
+
class_count[cl] = int(np.sum(truth_t[0]["sed"][:, cl] >= self.class_weights_thresholds[cl]))
|
1224
1313
|
|
1225
1314
|
return class_count
|
1226
1315
|
|
@@ -1300,7 +1389,12 @@ class MixtureDatabase:
|
|
1300
1389
|
|
1301
1390
|
return results
|
1302
1391
|
|
1303
|
-
def mixids_for_speech_metadata(
|
1392
|
+
def mixids_for_speech_metadata(
|
1393
|
+
self,
|
1394
|
+
tier: str | None = None,
|
1395
|
+
value: str | None = None,
|
1396
|
+
where: str | None = None,
|
1397
|
+
) -> list[int]:
|
1304
1398
|
"""Get a list of mixture IDs for the given speech metadata tier.
|
1305
1399
|
|
1306
1400
|
If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
@@ -1310,35 +1404,35 @@ class MixtureDatabase:
|
|
1310
1404
|
Examples:
|
1311
1405
|
>>> mixdb = MixtureDatabase('/mixdb_location')
|
1312
1406
|
|
1313
|
-
>>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', '
|
1314
|
-
Get
|
1407
|
+
>>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ABW0')
|
1408
|
+
Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ABW0'.
|
1315
1409
|
|
1316
|
-
>>> mixids = mixdb.mixids_for_speech_metadata(
|
1317
|
-
Get mixture IDs for mixtures with speakers whose ages are
|
1410
|
+
>>> mixids = mixdb.mixids_for_speech_metadata(where='age >= 27')
|
1411
|
+
Get mixture IDs for mixtures with speakers whose ages are greater than or equal to 27.
|
1318
1412
|
|
1319
|
-
>>> mixids = mixdb.mixids_for_speech_metadata(
|
1413
|
+
>>> mixids = mixdb.mixids_for_speech_metadata(where="dialect in ('New York City', 'Northern')")
|
1320
1414
|
Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
|
1321
1415
|
"""
|
1322
1416
|
if value is None and where is None:
|
1323
1417
|
raise ValueError("Must provide either value or where")
|
1324
1418
|
|
1325
1419
|
if where is None:
|
1420
|
+
if tier is None:
|
1421
|
+
raise ValueError("Must provide tier")
|
1326
1422
|
where = f"{tier} = '{value}'"
|
1327
1423
|
|
1328
|
-
if tier in self.textgrid_metadata_tiers:
|
1424
|
+
if tier is not None and tier in self.textgrid_metadata_tiers:
|
1329
1425
|
raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
|
1330
1426
|
|
1331
1427
|
with self.db() as c:
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
results = c.execute(
|
1336
|
-
|
1337
|
-
|
1338
|
-
target_file_ids = [target_file_id[0] for target_file_id in results]
|
1428
|
+
results = c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
|
1429
|
+
speaker_ids = ",".join(map(str, [i[0] for i in results]))
|
1430
|
+
|
1431
|
+
results = c.execute(f"SELECT id FROM target_file WHERE speaker_id IN ({speaker_ids})").fetchall()
|
1432
|
+
target_file_ids = ",".join(map(str, [i[0] for i in results]))
|
1433
|
+
|
1339
1434
|
results = c.execute(
|
1340
|
-
"SELECT mixture_id FROM mixture_target "
|
1341
|
-
+ f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
|
1435
|
+
f"SELECT mixture_id FROM mixture_target WHERE mixture_target.target_id IN ({target_file_ids})"
|
1342
1436
|
).fetchall()
|
1343
1437
|
|
1344
1438
|
return [mixture_id[0] - 1 for mixture_id in results]
|
@@ -1348,9 +1442,29 @@ class MixtureDatabase:
|
|
1348
1442
|
|
1349
1443
|
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
1350
1444
|
|
1351
|
-
def
|
1352
|
-
|
1353
|
-
|
1445
|
+
def cached_metrics(self, m_ids: GeneralizedIDs = "*") -> list[str]:
    """List the supported metrics that have a cached .pkl file in every selected mixture.

    For each mixture ID, the stems of the '*.pkl' files in
    <location>/mixture/<mixture name> are collected. The first mixture's stems
    are filtered to supported metric names; every later mixture narrows that
    set to the stems it also has.

    :param m_ids: Mixture ID(s) to inspect, expanded with mixids_to_list()
    :return: Sorted metric names common to all inspected mixtures (empty if none)
    """
    from glob import glob
    from os.path import join
    from pathlib import Path

    supported = self.supported_metrics.names
    common: set[str] | None = None
    for mixid in self.mixids_to_list(m_ids):
        pattern = join(self.location, "mixture", self.mixture(mixid).name, "*.pkl")
        stems = {Path(path).stem for path in glob(pattern)}
        if common is None:
            # First mixture seeds the result with its supported metrics only
            common = {stem for stem in stems if stem in supported}
        else:
            common &= stems

    return sorted(common) if common is not None else []
|
1466
|
+
|
1467
|
+
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> list[Any]:
|
1354
1468
|
"""Get metrics data for the given mixture ID
|
1355
1469
|
|
1356
1470
|
:param m_id: Zero-based mixture ID
|
@@ -1375,10 +1489,23 @@ class MixtureDatabase:
|
|
1375
1489
|
from sonusai.mixture import SpeechMetrics
|
1376
1490
|
from sonusai.utils import calc_asr
|
1377
1491
|
|
1378
|
-
def
|
1379
|
-
state = None
|
1492
|
+
def create_targets_audio() -> Callable[[], list[AudioT]]:
|
1493
|
+
state: list[AudioT] | None = None
|
1494
|
+
|
1495
|
+
def get() -> list[AudioT]:
|
1496
|
+
nonlocal state
|
1497
|
+
if state is None:
|
1498
|
+
state = self.mixture_targets(m_id)
|
1499
|
+
return state
|
1500
|
+
|
1501
|
+
return get
|
1502
|
+
|
1503
|
+
targets_audio = create_targets_audio()
|
1504
|
+
|
1505
|
+
def create_target_audio() -> Callable[[], AudioT]:
|
1506
|
+
state: AudioT | None = None
|
1380
1507
|
|
1381
|
-
def get() ->
|
1508
|
+
def get() -> AudioT:
|
1382
1509
|
nonlocal state
|
1383
1510
|
if state is None:
|
1384
1511
|
state = self.mixture_target(m_id)
|
@@ -1388,10 +1515,10 @@ class MixtureDatabase:
|
|
1388
1515
|
|
1389
1516
|
target_audio = create_target_audio()
|
1390
1517
|
|
1391
|
-
def create_target_f() -> Callable[[],
|
1392
|
-
state = None
|
1518
|
+
def create_target_f() -> Callable[[], AudioF]:
|
1519
|
+
state: AudioF | None = None
|
1393
1520
|
|
1394
|
-
def get() ->
|
1521
|
+
def get() -> AudioF:
|
1395
1522
|
nonlocal state
|
1396
1523
|
if state is None:
|
1397
1524
|
state = self.mixture_targets_f(m_id)[0]
|
@@ -1401,10 +1528,10 @@ class MixtureDatabase:
|
|
1401
1528
|
|
1402
1529
|
target_f = create_target_f()
|
1403
1530
|
|
1404
|
-
def create_noise_audio() -> Callable[[],
|
1405
|
-
state = None
|
1531
|
+
def create_noise_audio() -> Callable[[], AudioT]:
|
1532
|
+
state: AudioT | None = None
|
1406
1533
|
|
1407
|
-
def get() ->
|
1534
|
+
def get() -> AudioT:
|
1408
1535
|
nonlocal state
|
1409
1536
|
if state is None:
|
1410
1537
|
state = self.mixture_noise(m_id)
|
@@ -1414,10 +1541,10 @@ class MixtureDatabase:
|
|
1414
1541
|
|
1415
1542
|
noise_audio = create_noise_audio()
|
1416
1543
|
|
1417
|
-
def create_noise_f() -> Callable[[],
|
1418
|
-
state = None
|
1544
|
+
def create_noise_f() -> Callable[[], AudioF]:
|
1545
|
+
state: AudioF | None = None
|
1419
1546
|
|
1420
|
-
def get() ->
|
1547
|
+
def get() -> AudioF:
|
1421
1548
|
nonlocal state
|
1422
1549
|
if state is None:
|
1423
1550
|
state = self.mixture_noise_f(m_id)
|
@@ -1427,10 +1554,10 @@ class MixtureDatabase:
|
|
1427
1554
|
|
1428
1555
|
noise_f = create_noise_f()
|
1429
1556
|
|
1430
|
-
def create_mixture_audio() -> Callable[[],
|
1431
|
-
state = None
|
1557
|
+
def create_mixture_audio() -> Callable[[], AudioT]:
|
1558
|
+
state: AudioT | None = None
|
1432
1559
|
|
1433
|
-
def get() ->
|
1560
|
+
def get() -> AudioT:
|
1434
1561
|
nonlocal state
|
1435
1562
|
if state is None:
|
1436
1563
|
state = self.mixture_mixture(m_id)
|
@@ -1440,10 +1567,10 @@ class MixtureDatabase:
|
|
1440
1567
|
|
1441
1568
|
mixture_audio = create_mixture_audio()
|
1442
1569
|
|
1443
|
-
def create_segsnr_f() -> Callable[[],
|
1444
|
-
state = None
|
1570
|
+
def create_segsnr_f() -> Callable[[], Segsnr]:
|
1571
|
+
state: Segsnr | None = None
|
1445
1572
|
|
1446
|
-
def get() ->
|
1573
|
+
def get() -> Segsnr:
|
1447
1574
|
nonlocal state
|
1448
1575
|
if state is None:
|
1449
1576
|
state = self.mixture_segsnr(m_id)
|
@@ -1453,21 +1580,38 @@ class MixtureDatabase:
|
|
1453
1580
|
|
1454
1581
|
segsnr_f = create_segsnr_f()
|
1455
1582
|
|
1456
|
-
def create_speech() -> Callable[[], SpeechMetrics]:
|
1457
|
-
state = None
|
1583
|
+
def create_speech() -> Callable[[], list[SpeechMetrics]]:
|
1584
|
+
state: list[SpeechMetrics] | None = None
|
1458
1585
|
|
1459
|
-
def get() -> SpeechMetrics:
|
1586
|
+
def get() -> list[SpeechMetrics]:
|
1460
1587
|
nonlocal state
|
1461
1588
|
if state is None:
|
1462
|
-
state =
|
1589
|
+
state = []
|
1590
|
+
for audio in targets_audio():
|
1591
|
+
state.append(calc_speech(hypothesis=mixture_audio(), reference=audio))
|
1463
1592
|
return state
|
1464
1593
|
|
1465
1594
|
return get
|
1466
1595
|
|
1467
1596
|
speech = create_speech()
|
1468
1597
|
|
1598
|
+
def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
|
1599
|
+
state: list[AudioStatsMetrics] | None = None
|
1600
|
+
|
1601
|
+
def get() -> list[AudioStatsMetrics]:
|
1602
|
+
nonlocal state
|
1603
|
+
if state is None:
|
1604
|
+
state = []
|
1605
|
+
for audio in targets_audio():
|
1606
|
+
state.append(calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE))
|
1607
|
+
return state
|
1608
|
+
|
1609
|
+
return get
|
1610
|
+
|
1611
|
+
targets_stats = create_targets_stats()
|
1612
|
+
|
1469
1613
|
def create_target_stats() -> Callable[[], AudioStatsMetrics]:
|
1470
|
-
state = None
|
1614
|
+
state: AudioStatsMetrics | None = None
|
1471
1615
|
|
1472
1616
|
def get() -> AudioStatsMetrics:
|
1473
1617
|
nonlocal state
|
@@ -1480,7 +1624,7 @@ class MixtureDatabase:
|
|
1480
1624
|
target_stats = create_target_stats()
|
1481
1625
|
|
1482
1626
|
def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
|
1483
|
-
state = None
|
1627
|
+
state: AudioStatsMetrics | None = None
|
1484
1628
|
|
1485
1629
|
def get() -> AudioStatsMetrics:
|
1486
1630
|
nonlocal state
|
@@ -1508,6 +1652,21 @@ class MixtureDatabase:
|
|
1508
1652
|
|
1509
1653
|
asr_config = create_asr_config()
|
1510
1654
|
|
1655
|
+
def create_targets_asr() -> Callable[[str], list[str]]:
|
1656
|
+
state: dict[str, list[str]] = {}
|
1657
|
+
|
1658
|
+
def get(asr_name) -> list[str]:
|
1659
|
+
nonlocal state
|
1660
|
+
if asr_name not in state:
|
1661
|
+
state[asr_name] = []
|
1662
|
+
for audio in targets_audio():
|
1663
|
+
state[asr_name].append(calc_asr(audio, **asr_config(asr_name)).text)
|
1664
|
+
return state[asr_name]
|
1665
|
+
|
1666
|
+
return get
|
1667
|
+
|
1668
|
+
targets_asr = create_targets_asr()
|
1669
|
+
|
1511
1670
|
def create_target_asr() -> Callable[[str], str]:
|
1512
1671
|
state: dict[str, str] = {}
|
1513
1672
|
|
@@ -1541,7 +1700,7 @@ class MixtureDatabase:
|
|
1541
1700
|
asr_name = parts[1]
|
1542
1701
|
return asr_name
|
1543
1702
|
|
1544
|
-
def calc(m: str) ->
|
1703
|
+
def calc(m: str) -> Any:
|
1545
1704
|
if m == "mxsnr":
|
1546
1705
|
return self.mixture(m_id).snr
|
1547
1706
|
|
@@ -1555,7 +1714,7 @@ class MixtureDatabase:
|
|
1555
1714
|
if m.startswith("mxwer"):
|
1556
1715
|
asr_name = get_asr_name(m)
|
1557
1716
|
|
1558
|
-
if self.mixture(m_id).
|
1717
|
+
if self.mixture(m_id).is_noise_only:
|
1559
1718
|
# noise only, ignore/reset target asr
|
1560
1719
|
return float("nan")
|
1561
1720
|
|
@@ -1569,11 +1728,11 @@ class MixtureDatabase:
|
|
1569
1728
|
asr_name = get_asr_name(m)
|
1570
1729
|
|
1571
1730
|
text = self.mixture_speech_metadata(m_id, "text")[0]
|
1572
|
-
if text
|
1573
|
-
|
1731
|
+
if not isinstance(text, str):
|
1732
|
+
# TODO: should this be NaN like above?
|
1733
|
+
return [float(0)] * len(targets_audio())
|
1574
1734
|
|
1575
|
-
|
1576
|
-
return float(0)
|
1735
|
+
return [calc_wer(t, text).wer * 100 for t in targets_asr(asr_name)]
|
1577
1736
|
|
1578
1737
|
if m.startswith("mxasr"):
|
1579
1738
|
return mixture_asr(get_asr_name(m))
|
@@ -1603,24 +1762,24 @@ class MixtureDatabase:
|
|
1603
1762
|
return calc_segsnr_f_bin(target_f(), noise_f()).db_std
|
1604
1763
|
|
1605
1764
|
if m == "mxpesq":
|
1606
|
-
if self.mixture(m_id).
|
1607
|
-
return 0
|
1608
|
-
return speech()
|
1765
|
+
if self.mixture(m_id).is_noise_only:
|
1766
|
+
return [0] * len(speech())
|
1767
|
+
return [s.pesq for s in speech()]
|
1609
1768
|
|
1610
1769
|
if m == "mxcsig":
|
1611
|
-
if self.mixture(m_id).
|
1612
|
-
return 0
|
1613
|
-
return speech()
|
1770
|
+
if self.mixture(m_id).is_noise_only:
|
1771
|
+
return [0] * len(speech())
|
1772
|
+
return [s.csig for s in speech()]
|
1614
1773
|
|
1615
1774
|
if m == "mxcbak":
|
1616
|
-
if self.mixture(m_id).
|
1617
|
-
return 0
|
1618
|
-
return speech()
|
1775
|
+
if self.mixture(m_id).is_noise_only:
|
1776
|
+
return [0] * len(speech())
|
1777
|
+
return [s.cbak for s in speech()]
|
1619
1778
|
|
1620
1779
|
if m == "mxcovl":
|
1621
|
-
if self.mixture(m_id).
|
1622
|
-
return 0
|
1623
|
-
return speech()
|
1780
|
+
if self.mixture(m_id).is_noise_only:
|
1781
|
+
return [0] * len(speech())
|
1782
|
+
return [s.covl for s in speech()]
|
1624
1783
|
|
1625
1784
|
if m == "mxwsdr":
|
1626
1785
|
mixture = mixture_audio()[:, np.newaxis]
|
@@ -1644,37 +1803,70 @@ class MixtureDatabase:
|
|
1644
1803
|
extended=False,
|
1645
1804
|
)
|
1646
1805
|
|
1647
|
-
if m == "
|
1806
|
+
if m == "mxtdco":
|
1648
1807
|
return target_stats().dco
|
1649
1808
|
|
1650
|
-
if m == "
|
1809
|
+
if m == "mxtmin":
|
1651
1810
|
return target_stats().min
|
1652
1811
|
|
1653
|
-
if m == "
|
1812
|
+
if m == "mxtmax":
|
1654
1813
|
return target_stats().max
|
1655
1814
|
|
1656
|
-
if m == "
|
1815
|
+
if m == "mxtpkdb":
|
1657
1816
|
return target_stats().pkdb
|
1658
1817
|
|
1659
|
-
if m == "
|
1818
|
+
if m == "mxtlrms":
|
1660
1819
|
return target_stats().lrms
|
1661
1820
|
|
1662
|
-
if m == "
|
1821
|
+
if m == "mxtpkr":
|
1663
1822
|
return target_stats().pkr
|
1664
1823
|
|
1665
|
-
if m == "
|
1824
|
+
if m == "mxttr":
|
1666
1825
|
return target_stats().tr
|
1667
1826
|
|
1668
|
-
if m == "
|
1827
|
+
if m == "mxtcr":
|
1669
1828
|
return target_stats().cr
|
1670
1829
|
|
1671
|
-
if m == "
|
1830
|
+
if m == "mxtfl":
|
1672
1831
|
return target_stats().fl
|
1673
1832
|
|
1674
|
-
if m == "
|
1833
|
+
if m == "mxtpkc":
|
1675
1834
|
return target_stats().pkc
|
1676
1835
|
|
1836
|
+
if m == "tdco":
|
1837
|
+
return [t.dco for t in targets_stats()]
|
1838
|
+
|
1839
|
+
if m == "tmin":
|
1840
|
+
return [t.min for t in targets_stats()]
|
1841
|
+
|
1842
|
+
if m == "tmax":
|
1843
|
+
return [t.max for t in targets_stats()]
|
1844
|
+
|
1845
|
+
if m == "tpkdb":
|
1846
|
+
return [t.pkdb for t in targets_stats()]
|
1847
|
+
|
1848
|
+
if m == "tlrms":
|
1849
|
+
return [t.lrms for t in targets_stats()]
|
1850
|
+
|
1851
|
+
if m == "tpkr":
|
1852
|
+
return [t.pkr for t in targets_stats()]
|
1853
|
+
|
1854
|
+
if m == "ttr":
|
1855
|
+
return [t.tr for t in targets_stats()]
|
1856
|
+
|
1857
|
+
if m == "tcr":
|
1858
|
+
return [t.cr for t in targets_stats()]
|
1859
|
+
|
1860
|
+
if m == "tfl":
|
1861
|
+
return [t.fl for t in targets_stats()]
|
1862
|
+
|
1863
|
+
if m == "tpkc":
|
1864
|
+
return [t.pkc for t in targets_stats()]
|
1865
|
+
|
1677
1866
|
if m.startswith("tasr"):
|
1867
|
+
return targets_asr(get_asr_name(m))
|
1868
|
+
|
1869
|
+
if m.startswith("mxtasr"):
|
1678
1870
|
return target_asr(get_asr_name(m))
|
1679
1871
|
|
1680
1872
|
if m == "ndco":
|
@@ -1743,7 +1935,14 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1743
1935
|
|
1744
1936
|
with db() as c:
|
1745
1937
|
spectral_mask = SpectralMaskRecord(
|
1746
|
-
*c.execute(
|
1938
|
+
*c.execute(
|
1939
|
+
"""
|
1940
|
+
SELECT *
|
1941
|
+
FROM spectral_mask
|
1942
|
+
WHERE ? = spectral_mask.id
|
1943
|
+
""",
|
1944
|
+
(sm_id,),
|
1945
|
+
).fetchone()
|
1747
1946
|
)
|
1748
1947
|
return SpectralMask(
|
1749
1948
|
f_max_width=spectral_mask.f_max_width,
|
@@ -1768,7 +1967,14 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1768
1967
|
|
1769
1968
|
with db() as c:
|
1770
1969
|
target_file = TargetFileRecord(
|
1771
|
-
*c.execute(
|
1970
|
+
*c.execute(
|
1971
|
+
"""
|
1972
|
+
SELECT *
|
1973
|
+
FROM target_file
|
1974
|
+
WHERE ? = target_file.id
|
1975
|
+
""",
|
1976
|
+
(t_id,),
|
1977
|
+
).fetchone()
|
1772
1978
|
)
|
1773
1979
|
|
1774
1980
|
return TargetFile(
|
@@ -1791,7 +1997,11 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
1791
1997
|
"""
|
1792
1998
|
with db() as c:
|
1793
1999
|
noise = c.execute(
|
1794
|
-
"
|
2000
|
+
"""
|
2001
|
+
SELECT noise_file.name, samples
|
2002
|
+
FROM noise_file
|
2003
|
+
WHERE ? = noise_file.id
|
2004
|
+
""",
|
1795
2005
|
(n_id,),
|
1796
2006
|
).fetchone()
|
1797
2007
|
return NoiseFile(name=noise[0], samples=noise[1])
|
@@ -1808,7 +2018,11 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
1808
2018
|
with db() as c:
|
1809
2019
|
return str(
|
1810
2020
|
c.execute(
|
1811
|
-
"
|
2021
|
+
"""
|
2022
|
+
SELECT impulse_response_file.file
|
2023
|
+
FROM impulse_response_file
|
2024
|
+
WHERE ? = impulse_response_file.id
|
2025
|
+
""",
|
1812
2026
|
(ir_id + 1,),
|
1813
2027
|
).fetchone()[0]
|
1814
2028
|
)
|
@@ -1828,13 +2042,25 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
1828
2042
|
from .helpers import to_target
|
1829
2043
|
|
1830
2044
|
with db() as c:
|
1831
|
-
mixture = MixtureRecord(
|
2045
|
+
mixture = MixtureRecord(
|
2046
|
+
*c.execute(
|
2047
|
+
"""
|
2048
|
+
SELECT *
|
2049
|
+
FROM mixture
|
2050
|
+
WHERE ? = mixture.id
|
2051
|
+
""",
|
2052
|
+
(m_id + 1,),
|
2053
|
+
).fetchone()
|
2054
|
+
)
|
2055
|
+
|
1832
2056
|
targets = [
|
1833
2057
|
to_target(TargetRecord(*target))
|
1834
2058
|
for target in c.execute(
|
1835
|
-
"
|
1836
|
-
|
1837
|
-
|
2059
|
+
"""
|
2060
|
+
SELECT target.*
|
2061
|
+
FROM target, mixture_target
|
2062
|
+
WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
|
2063
|
+
""",
|
1838
2064
|
(mixture.id,),
|
1839
2065
|
).fetchall()
|
1840
2066
|
]
|
@@ -1865,10 +2091,11 @@ def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
|
1865
2091
|
truth_configs: TruthConfigs = {}
|
1866
2092
|
with db() as c:
|
1867
2093
|
for truth_config_record in c.execute(
|
1868
|
-
"
|
1869
|
-
|
1870
|
-
|
1871
|
-
|
2094
|
+
"""
|
2095
|
+
SELECT truth_config.config
|
2096
|
+
FROM truth_config, target_file_truth_config
|
2097
|
+
WHERE ? = target_file_truth_config.target_file_id AND truth_config.id = target_file_truth_config.truth_config_id
|
2098
|
+
""",
|
1872
2099
|
(t_id,),
|
1873
2100
|
).fetchall():
|
1874
2101
|
truth_config = json.loads(truth_config_record[0])
|