sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +240 -76
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +23 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -17
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +5 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +484 -611
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +931 -669
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
- sonusai-1.0.1.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.2.dist-info/RECORD +0 -128
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
- {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -6,36 +6,43 @@ from sqlite3 import Connection
|
|
6
6
|
from sqlite3 import Cursor
|
7
7
|
from typing import Any
|
8
8
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
25
|
-
from
|
26
|
-
from
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
9
|
+
from ..datatypes import ASRConfigs
|
10
|
+
from ..datatypes import AudioF
|
11
|
+
from ..datatypes import AudioT
|
12
|
+
from ..datatypes import ClassCount
|
13
|
+
from ..datatypes import Feature
|
14
|
+
from ..datatypes import FeatureGeneratorConfig
|
15
|
+
from ..datatypes import FeatureGeneratorInfo
|
16
|
+
from ..datatypes import GeneralizedIDs
|
17
|
+
from ..datatypes import ImpulseResponseFile
|
18
|
+
from ..datatypes import MetricDoc
|
19
|
+
from ..datatypes import MetricDocs
|
20
|
+
from ..datatypes import Mixture
|
21
|
+
from ..datatypes import Segsnr
|
22
|
+
from ..datatypes import SourceFile
|
23
|
+
from ..datatypes import Sources
|
24
|
+
from ..datatypes import SourcesAudioF
|
25
|
+
from ..datatypes import SourcesAudioT
|
26
|
+
from ..datatypes import SpectralMask
|
27
|
+
from ..datatypes import SpeechMetadata
|
28
|
+
from ..datatypes import TransformConfig
|
29
|
+
from ..datatypes import TruthConfigs
|
30
|
+
from ..datatypes import TruthDict
|
31
|
+
from ..datatypes import TruthsConfigs
|
32
|
+
from ..datatypes import TruthsDict
|
33
|
+
from ..datatypes import UniversalSNR
|
30
34
|
|
31
35
|
|
32
36
|
def db_file(location: str, test: bool = False) -> str:
|
33
37
|
from os.path import join
|
34
38
|
|
39
|
+
from .constants import MIXDB_NAME
|
40
|
+
from .constants import TEST_MIXDB_NAME
|
41
|
+
|
35
42
|
if test:
|
36
|
-
name =
|
43
|
+
name = TEST_MIXDB_NAME
|
37
44
|
else:
|
38
|
-
name =
|
45
|
+
name = MIXDB_NAME
|
39
46
|
|
40
47
|
return join(location, name)
|
41
48
|
|
@@ -113,7 +120,7 @@ class MixtureDatabase:
|
|
113
120
|
|
114
121
|
@cached_property
|
115
122
|
def json(self) -> str:
|
116
|
-
from
|
123
|
+
from ..datatypes import MixtureDatabaseConfig
|
117
124
|
|
118
125
|
config = MixtureDatabaseConfig(
|
119
126
|
asr_configs=self.asr_configs,
|
@@ -121,13 +128,11 @@ class MixtureDatabase:
|
|
121
128
|
class_labels=self.class_labels,
|
122
129
|
class_weights_threshold=self.class_weights_thresholds,
|
123
130
|
feature=self.feature,
|
124
|
-
|
125
|
-
mixtures=self.mixtures
|
126
|
-
noise_mix_mode=self.noise_mix_mode,
|
127
|
-
noise_files=self.noise_files,
|
131
|
+
ir_files=self.ir_files,
|
132
|
+
mixtures=self.mixtures,
|
128
133
|
num_classes=self.num_classes,
|
129
134
|
spectral_masks=self.spectral_masks,
|
130
|
-
|
135
|
+
source_files=self.source_files,
|
131
136
|
)
|
132
137
|
return config.to_json(indent=2)
|
133
138
|
|
@@ -153,12 +158,15 @@ class MixtureDatabase:
|
|
153
158
|
return get_feature_generator_info(self.fg_config)
|
154
159
|
|
155
160
|
@cached_property
|
156
|
-
def truth_parameters(self) -> dict[str, int | None]:
|
161
|
+
def truth_parameters(self) -> dict[str, dict[str, int | None]]:
|
157
162
|
with self.db() as c:
|
158
|
-
rows = c.execute("SELECT
|
159
|
-
truth_parameters: dict[str, int | None] = {}
|
163
|
+
rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
|
164
|
+
truth_parameters: dict[str, dict[str, int | None]] = {}
|
160
165
|
for row in rows:
|
161
|
-
|
166
|
+
category, name, parameters = row
|
167
|
+
if category not in truth_parameters:
|
168
|
+
truth_parameters[category] = {}
|
169
|
+
truth_parameters[category][name] = parameters
|
162
170
|
return truth_parameters
|
163
171
|
|
164
172
|
@cached_property
|
@@ -166,11 +174,6 @@ class MixtureDatabase:
|
|
166
174
|
with self.db() as c:
|
167
175
|
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
168
176
|
|
169
|
-
@cached_property
|
170
|
-
def noise_mix_mode(self) -> str:
|
171
|
-
with self.db() as c:
|
172
|
-
return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
|
173
|
-
|
174
177
|
@cached_property
|
175
178
|
def asr_configs(self) -> ASRConfigs:
|
176
179
|
import json
|
@@ -223,36 +226,36 @@ class MixtureDatabase:
|
|
223
226
|
"mxssnrdbf_std",
|
224
227
|
"Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
|
225
228
|
),
|
226
|
-
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true
|
229
|
+
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true sources"),
|
227
230
|
MetricDoc(
|
228
231
|
"Mixture Metrics",
|
229
232
|
"mxwsdr",
|
230
|
-
"Weighted signal distortion ratio of mixture versus true
|
233
|
+
"Weighted signal distortion ratio of mixture versus true sources",
|
231
234
|
),
|
232
235
|
MetricDoc(
|
233
236
|
"Mixture Metrics",
|
234
237
|
"mxpd",
|
235
|
-
"Phase distance between mixture and true
|
238
|
+
"Phase distance between mixture and true sources",
|
236
239
|
),
|
237
240
|
MetricDoc(
|
238
241
|
"Mixture Metrics",
|
239
242
|
"mxstoi",
|
240
|
-
"Short term objective intelligibility of mixture versus true
|
243
|
+
"Short term objective intelligibility of mixture versus true sources",
|
241
244
|
),
|
242
245
|
MetricDoc(
|
243
246
|
"Mixture Metrics",
|
244
247
|
"mxcsig",
|
245
|
-
"Predicted rating of speech distortion of mixture versus true
|
248
|
+
"Predicted rating of speech distortion of mixture versus true sources",
|
246
249
|
),
|
247
250
|
MetricDoc(
|
248
251
|
"Mixture Metrics",
|
249
252
|
"mxcbak",
|
250
|
-
"Predicted rating of background distortion of mixture versus true
|
253
|
+
"Predicted rating of background distortion of mixture versus true sources",
|
251
254
|
),
|
252
255
|
MetricDoc(
|
253
256
|
"Mixture Metrics",
|
254
257
|
"mxcovl",
|
255
|
-
"Predicted rating of overall quality of mixture versus true
|
258
|
+
"Predicted rating of overall quality of mixture versus true sources",
|
256
259
|
),
|
257
260
|
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
258
261
|
MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
|
@@ -265,26 +268,26 @@ class MixtureDatabase:
|
|
265
268
|
MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
|
266
269
|
MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
|
267
270
|
MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
|
268
|
-
MetricDoc("Mixture Metrics", "mxtdco", "Mixture
|
269
|
-
MetricDoc("Mixture Metrics", "mxtmin", "Mixture
|
270
|
-
MetricDoc("Mixture Metrics", "mxtmax", "Mixture
|
271
|
-
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture
|
272
|
-
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture
|
273
|
-
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture
|
274
|
-
MetricDoc("Mixture Metrics", "mxttr", "Mixture
|
275
|
-
MetricDoc("Mixture Metrics", "mxtcr", "Mixture
|
276
|
-
MetricDoc("Mixture Metrics", "mxtfl", "Mixture
|
277
|
-
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture
|
278
|
-
MetricDoc("
|
279
|
-
MetricDoc("
|
280
|
-
MetricDoc("
|
281
|
-
MetricDoc("
|
282
|
-
MetricDoc("
|
283
|
-
MetricDoc("
|
284
|
-
MetricDoc("
|
285
|
-
MetricDoc("
|
286
|
-
MetricDoc("
|
287
|
-
MetricDoc("
|
271
|
+
MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
|
272
|
+
MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
|
273
|
+
MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
|
274
|
+
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
|
275
|
+
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
|
276
|
+
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
|
277
|
+
MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
|
278
|
+
MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
|
279
|
+
MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
|
280
|
+
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
|
281
|
+
MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
|
282
|
+
MetricDoc("Sources Metrics", "smin", "Sources min level"),
|
283
|
+
MetricDoc("Sources Metrics", "smax", "Sources max levl"),
|
284
|
+
MetricDoc("Sources Metrics", "spkdb", "Sources Pk lev dB"),
|
285
|
+
MetricDoc("Sources Metrics", "slrms", "Sources RMS lev dB"),
|
286
|
+
MetricDoc("Sources Metrics", "spkr", "Sources RMS Pk dB"),
|
287
|
+
MetricDoc("Sources Metrics", "str", "Sources RMS Tr dB"),
|
288
|
+
MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
|
289
|
+
MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
|
290
|
+
MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
|
288
291
|
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
289
292
|
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
290
293
|
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
@@ -320,16 +323,16 @@ class MixtureDatabase:
|
|
320
323
|
for name in self.asr_configs:
|
321
324
|
metrics.append(
|
322
325
|
MetricDoc(
|
323
|
-
"
|
324
|
-
f"
|
325
|
-
f"Mixture
|
326
|
+
"Source Metrics",
|
327
|
+
f"mxsasr.{name}",
|
328
|
+
f"Mixture Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
326
329
|
)
|
327
330
|
)
|
328
331
|
metrics.append(
|
329
332
|
MetricDoc(
|
330
|
-
"
|
331
|
-
f"
|
332
|
-
f"
|
333
|
+
"Source Metrics",
|
334
|
+
f"sasr.{name}",
|
335
|
+
f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
333
336
|
)
|
334
337
|
)
|
335
338
|
metrics.append(
|
@@ -341,16 +344,16 @@ class MixtureDatabase:
|
|
341
344
|
)
|
342
345
|
metrics.append(
|
343
346
|
MetricDoc(
|
344
|
-
"
|
347
|
+
"Source Metrics",
|
345
348
|
f"basewer.{name}",
|
346
|
-
f"Word error rate of
|
349
|
+
f"Word error rate of sasr.{name} vs. speech text metadata for the source",
|
347
350
|
)
|
348
351
|
)
|
349
352
|
metrics.append(
|
350
353
|
MetricDoc(
|
351
354
|
"Mixture Metrics",
|
352
355
|
f"mxwer.{name}",
|
353
|
-
f"Word error rate of mxasr.{name} vs.
|
356
|
+
f"Word error rate of mxasr.{name} vs. sasr.{name}",
|
354
357
|
)
|
355
358
|
)
|
356
359
|
|
@@ -396,7 +399,7 @@ class MixtureDatabase:
|
|
396
399
|
|
397
400
|
@cached_property
|
398
401
|
def transform_frame_ms(self) -> float:
|
399
|
-
from
|
402
|
+
from ..constants import SAMPLE_RATE
|
400
403
|
|
401
404
|
return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
|
402
405
|
|
@@ -417,12 +420,7 @@ class MixtureDatabase:
|
|
417
420
|
return self.ft_config.overlap * self.fg_decimation * self.fg_step
|
418
421
|
|
419
422
|
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
|
420
|
-
samples
|
421
|
-
for m_id in self.mixids_to_list(m_ids):
|
422
|
-
s = self.mixture(m_id).samples
|
423
|
-
if s is not None:
|
424
|
-
samples += s
|
425
|
-
return samples
|
423
|
+
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
|
426
424
|
|
427
425
|
def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
428
426
|
return self.total_samples(m_ids) // self.ft_config.overlap
|
@@ -476,30 +474,18 @@ class MixtureDatabase:
|
|
476
474
|
).fetchall()
|
477
475
|
]
|
478
476
|
|
479
|
-
|
480
|
-
|
481
|
-
"""Get truth configs from db
|
482
|
-
|
483
|
-
:return: Truth configs
|
484
|
-
"""
|
485
|
-
import json
|
477
|
+
def category_truth_configs(self, category: str) -> dict[str, str]:
|
478
|
+
return _category_truth_configs(self.db, category, self.use_cache)
|
486
479
|
|
487
|
-
|
480
|
+
def source_truth_configs(self, s_id: int) -> TruthConfigs:
|
481
|
+
return _source_truth_configs(self.db, s_id, self.use_cache)
|
488
482
|
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
function=truth_config["function"],
|
496
|
-
stride_reduction=truth_config["stride_reduction"],
|
497
|
-
config=truth_config["config"],
|
498
|
-
)
|
499
|
-
return truth_configs
|
500
|
-
|
501
|
-
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
502
|
-
return _target_truth_configs(self.db, t_id, self.use_cache)
|
483
|
+
def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
|
484
|
+
mixture = self.mixture(m_id)
|
485
|
+
return {
|
486
|
+
category: self.source_truth_configs(mixture.all_sources[category].file_id)
|
487
|
+
for category in mixture.all_sources
|
488
|
+
}
|
503
489
|
|
504
490
|
@cached_property
|
505
491
|
def random_snrs(self) -> list[float]:
|
@@ -509,10 +495,7 @@ class MixtureDatabase:
|
|
509
495
|
"""
|
510
496
|
with self.db() as c:
|
511
497
|
return list(
|
512
|
-
{
|
513
|
-
float(item[0])
|
514
|
-
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
|
515
|
-
}
|
498
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()}
|
516
499
|
)
|
517
500
|
|
518
501
|
@cached_property
|
@@ -523,10 +506,7 @@ class MixtureDatabase:
|
|
523
506
|
"""
|
524
507
|
with self.db() as c:
|
525
508
|
return list(
|
526
|
-
{
|
527
|
-
float(item[0])
|
528
|
-
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
|
529
|
-
}
|
509
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()}
|
530
510
|
)
|
531
511
|
|
532
512
|
@cached_property
|
@@ -570,199 +550,246 @@ class MixtureDatabase:
|
|
570
550
|
return _spectral_mask(self.db, sm_id, self.use_cache)
|
571
551
|
|
572
552
|
@cached_property
|
573
|
-
def
|
574
|
-
"""Get
|
553
|
+
def source_files(self) -> dict[str, list[SourceFile]]:
|
554
|
+
"""Get source files from db
|
575
555
|
|
576
|
-
:return:
|
556
|
+
:return: Source files
|
577
557
|
"""
|
578
558
|
import json
|
579
559
|
|
580
|
-
from
|
581
|
-
from
|
582
|
-
from .db_datatypes import
|
560
|
+
from ..datatypes import TruthConfig
|
561
|
+
from ..datatypes import TruthConfigs
|
562
|
+
from .db_datatypes import SourceFileRecord
|
583
563
|
|
584
564
|
with self.db() as c:
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
(
|
599
|
-
|
600
|
-
|
601
|
-
truth_configs
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
truth_configs=
|
613
|
-
|
565
|
+
source_files: dict[str, list[SourceFile]] = {}
|
566
|
+
categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
|
567
|
+
for category in categories:
|
568
|
+
source_files[category[0]] = []
|
569
|
+
source_file_records = [
|
570
|
+
SourceFileRecord(*result)
|
571
|
+
for result in c.execute(
|
572
|
+
"""
|
573
|
+
SELECT *
|
574
|
+
FROM source_file
|
575
|
+
WHERE ? = source_file.category
|
576
|
+
""",
|
577
|
+
(category[0],),
|
578
|
+
).fetchall()
|
579
|
+
]
|
580
|
+
for source_file_record in source_file_records:
|
581
|
+
truth_configs: TruthConfigs = {}
|
582
|
+
for truth_config_records in c.execute(
|
583
|
+
"""
|
584
|
+
SELECT truth_config.config
|
585
|
+
FROM truth_config, source_file_truth_config
|
586
|
+
WHERE ? = source_file_truth_config.source_file_id
|
587
|
+
AND truth_config.id = source_file_truth_config.truth_config_id
|
588
|
+
""",
|
589
|
+
(source_file_record.id,),
|
590
|
+
).fetchall():
|
591
|
+
truth_config = json.loads(truth_config_records[0])
|
592
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
593
|
+
function=truth_config["function"],
|
594
|
+
stride_reduction=truth_config["stride_reduction"],
|
595
|
+
config=truth_config["config"],
|
596
|
+
)
|
597
|
+
source_files[source_file_record.category].append(
|
598
|
+
SourceFile(
|
599
|
+
id=source_file_record.id,
|
600
|
+
category=source_file_record.category,
|
601
|
+
name=source_file_record.name,
|
602
|
+
samples=source_file_record.samples,
|
603
|
+
class_indices=json.loads(source_file_record.class_indices),
|
604
|
+
level_type=source_file_record.level_type,
|
605
|
+
truth_configs=truth_configs,
|
606
|
+
speaker_id=source_file_record.speaker_id,
|
607
|
+
)
|
614
608
|
)
|
615
|
-
|
616
|
-
return target_files
|
609
|
+
return source_files
|
617
610
|
|
618
611
|
@cached_property
|
619
|
-
def
|
620
|
-
"""Get
|
612
|
+
def source_file_ids(self) -> dict[str, list[int]]:
|
613
|
+
"""Get source file IDs from db
|
621
614
|
|
622
|
-
:return:
|
615
|
+
:return: Dictionary of list of source file IDs
|
623
616
|
"""
|
624
617
|
with self.db() as c:
|
625
|
-
|
618
|
+
source_file_ids: dict[str, list[int]] = {}
|
619
|
+
categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
|
620
|
+
for category in categories:
|
621
|
+
# items = c.execute(
|
622
|
+
# """
|
623
|
+
# SELECT source_file.id
|
624
|
+
# FROM source_file
|
625
|
+
# WHERE ? = source_file.category
|
626
|
+
# """,
|
627
|
+
# (category[0],),
|
628
|
+
# ).fetchall()
|
629
|
+
# source_file_ids[category[0]] = [int(item[0]) for item in items]
|
630
|
+
source_file_ids[category[0]] = [
|
631
|
+
int(item[0])
|
632
|
+
for item in c.execute(
|
633
|
+
"""
|
634
|
+
SELECT source_file.id
|
635
|
+
FROM source_file
|
636
|
+
WHERE ? = source_file.category
|
637
|
+
""",
|
638
|
+
(category[0],),
|
639
|
+
).fetchall()
|
640
|
+
]
|
641
|
+
return source_file_ids
|
626
642
|
|
627
|
-
def
|
628
|
-
"""Get
|
643
|
+
def source_file(self, s_id: int) -> SourceFile:
|
644
|
+
"""Get source file with ID from db
|
629
645
|
|
630
|
-
:param
|
631
|
-
:return:
|
646
|
+
:param s_id: Source file ID
|
647
|
+
:return: Source file
|
632
648
|
"""
|
633
|
-
return
|
649
|
+
return _source_file(self.db, s_id, self.use_cache)
|
634
650
|
|
635
|
-
|
636
|
-
|
637
|
-
"""Get number of target files from db
|
651
|
+
def num_source_files(self, category: str) -> int:
|
652
|
+
"""Get number of source files from category from db
|
638
653
|
|
639
|
-
:
|
654
|
+
:param category: Source category
|
655
|
+
:return: Number of source files
|
640
656
|
"""
|
641
|
-
|
642
|
-
return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
|
657
|
+
return _num_source_files(self.db, category, self.use_cache)
|
643
658
|
|
644
659
|
@cached_property
|
645
|
-
def
|
646
|
-
"""Get
|
660
|
+
def ir_files(self) -> list[ImpulseResponseFile]:
|
661
|
+
"""Get impulse response files from db
|
647
662
|
|
648
|
-
:return:
|
663
|
+
:return: Impulse response files
|
649
664
|
"""
|
650
|
-
|
651
|
-
return [
|
652
|
-
NoiseFile(name=noise[0], samples=noise[1])
|
653
|
-
for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
|
654
|
-
]
|
655
|
-
|
656
|
-
@cached_property
|
657
|
-
def noise_file_ids(self) -> list[int]:
|
658
|
-
"""Get noise file IDs from db
|
665
|
+
from .db_datatypes import ImpulseResponseFileRecord
|
659
666
|
|
660
|
-
:return: List of noise file IDs
|
661
|
-
"""
|
662
667
|
with self.db() as c:
|
663
|
-
|
668
|
+
files: list[ImpulseResponseFile] = []
|
669
|
+
entries = c.execute("SELECT * FROM ir_file").fetchall()
|
670
|
+
for entry in entries:
|
671
|
+
file = ImpulseResponseFileRecord(*entry)
|
672
|
+
|
673
|
+
tags = [
|
674
|
+
tag[0]
|
675
|
+
for tag in c.execute(
|
676
|
+
"""
|
677
|
+
SELECT ir_tag.tag
|
678
|
+
FROM ir_tag, ir_file_ir_tag
|
679
|
+
WHERE ? = ir_file_ir_tag.file_id
|
680
|
+
AND ir_tag.id = ir_file_ir_tag.tag_id
|
681
|
+
""",
|
682
|
+
(file.id,),
|
683
|
+
).fetchall()
|
684
|
+
]
|
664
685
|
|
665
|
-
|
666
|
-
|
686
|
+
files.append(
|
687
|
+
ImpulseResponseFile(
|
688
|
+
delay=file.delay,
|
689
|
+
name=file.name,
|
690
|
+
tags=tags,
|
691
|
+
)
|
692
|
+
)
|
667
693
|
|
668
|
-
|
669
|
-
:return: Noise file
|
670
|
-
"""
|
671
|
-
return _noise_file(self.db, n_id, self.use_cache)
|
694
|
+
return files
|
672
695
|
|
673
696
|
@cached_property
|
674
|
-
def
|
675
|
-
"""Get
|
697
|
+
def ir_file_ids(self) -> list[int]:
|
698
|
+
"""Get impulse response file IDs from db
|
676
699
|
|
677
|
-
:return:
|
700
|
+
:return: List of impulse response file IDs
|
678
701
|
"""
|
679
702
|
with self.db() as c:
|
680
|
-
return int(c.execute("SELECT
|
703
|
+
return [int(item[0]) for item in c.execute("SELECT ir_file.id FROM ir_file").fetchall()]
|
681
704
|
|
682
|
-
|
683
|
-
|
684
|
-
"""Get impulse response files from db
|
705
|
+
def ir_file_ids_for_tag(self, tag: str) -> list[int]:
|
706
|
+
"""Get impulse response file IDs for given tag from db
|
685
707
|
|
686
|
-
:return:
|
708
|
+
:return: List of impulse response file IDs for given tag
|
687
709
|
"""
|
688
|
-
import json
|
689
|
-
|
690
|
-
from .datatypes import ImpulseResponseFile
|
691
|
-
|
692
710
|
with self.db() as c:
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
"SELECT impulse_response_file.* FROM impulse_response_file"
|
697
|
-
).fetchall()
|
698
|
-
]
|
699
|
-
|
700
|
-
@cached_property
|
701
|
-
def impulse_response_file_ids(self) -> list[int]:
|
702
|
-
"""Get impulse response file IDs from db
|
711
|
+
tag_id = c.execute("SELECT ir_tag.id FROM ir_tag WHERE ? = ir_tag.tag", (tag,)).fetchone()
|
712
|
+
if not tag_id:
|
713
|
+
return []
|
703
714
|
|
704
|
-
:return: List of impulse response file IDs
|
705
|
-
"""
|
706
|
-
with self.db() as c:
|
707
715
|
return [
|
708
|
-
int(item[0])
|
709
|
-
for item in c.execute(
|
716
|
+
int(item[0] - 1)
|
717
|
+
for item in c.execute(
|
718
|
+
"""
|
719
|
+
SELECT ir_file_ir_tag.file_id
|
720
|
+
FROM ir_file_ir_tag
|
721
|
+
WHERE ? = ir_file_ir_tag.tag_id
|
722
|
+
""",
|
723
|
+
(tag_id[0],),
|
724
|
+
).fetchall()
|
710
725
|
]
|
711
726
|
|
712
|
-
def
|
727
|
+
def ir_file(self, ir_id: int) -> str:
|
713
728
|
"""Get impulse response file name with ID from db
|
714
729
|
|
715
730
|
:param ir_id: Impulse response file ID
|
716
731
|
:return: Impulse response file name
|
717
732
|
"""
|
718
|
-
|
719
|
-
return None
|
720
|
-
return _impulse_response_file(self.db, ir_id, self.use_cache)
|
733
|
+
return _ir_file(self.db, ir_id, self.use_cache)
|
721
734
|
|
722
|
-
def
|
735
|
+
def ir_delay(self, ir_id: int) -> int:
|
723
736
|
"""Get impulse response delay with ID from db
|
724
737
|
|
725
738
|
:param ir_id: Impulse response file ID
|
726
739
|
:return: Impulse response delay
|
727
740
|
"""
|
728
|
-
|
729
|
-
return None
|
730
|
-
return _impulse_response_delay(self.db, ir_id, self.use_cache)
|
741
|
+
return _ir_delay(self.db, ir_id, self.use_cache)
|
731
742
|
|
732
743
|
@cached_property
|
733
|
-
def
|
744
|
+
def num_ir_files(self) -> int:
|
734
745
|
"""Get number of impulse response files from db
|
735
746
|
|
736
747
|
:return: Number of impulse response files
|
737
748
|
"""
|
738
749
|
with self.db() as c:
|
739
|
-
return int(c.execute("SELECT count(
|
750
|
+
return int(c.execute("SELECT count(ir_file.id) FROM ir_file").fetchone()[0])
|
751
|
+
|
752
|
+
@cached_property
|
753
|
+
def ir_tags(self) -> list[str]:
|
754
|
+
"""Get tags of impulse response files from db
|
755
|
+
|
756
|
+
:return: Tags of impulse response files
|
757
|
+
"""
|
758
|
+
with self.db() as c:
|
759
|
+
return [tag[0] for tag in c.execute("SELECT ir_tag.tag FROM ir_tag").fetchall()]
|
740
760
|
|
761
|
+
@property
|
741
762
|
def mixtures(self) -> list[Mixture]:
|
742
763
|
"""Get mixtures from db
|
743
764
|
|
744
765
|
:return: Mixtures
|
745
766
|
"""
|
746
767
|
from .db_datatypes import MixtureRecord
|
747
|
-
from .db_datatypes import
|
768
|
+
from .db_datatypes import SourceRecord
|
748
769
|
from .helpers import to_mixture
|
749
|
-
from .helpers import
|
770
|
+
from .helpers import to_source
|
750
771
|
|
751
772
|
with self.db() as c:
|
752
773
|
mixtures: list[Mixture] = []
|
753
774
|
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
754
|
-
|
755
|
-
|
756
|
-
for
|
775
|
+
sources_list = [
|
776
|
+
to_source(SourceRecord(*source))
|
777
|
+
for source in c.execute(
|
757
778
|
"""
|
758
|
-
SELECT
|
759
|
-
FROM
|
760
|
-
WHERE ? =
|
779
|
+
SELECT source.*
|
780
|
+
FROM source, mixture_source
|
781
|
+
WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
|
761
782
|
""",
|
762
783
|
(mixture.id,),
|
763
784
|
).fetchall()
|
764
785
|
]
|
765
|
-
|
786
|
+
|
787
|
+
sources: Sources = {}
|
788
|
+
for source in sources_list:
|
789
|
+
sources[self.source_file(source.file_id).category] = source
|
790
|
+
|
791
|
+
mixtures.append(to_mixture(mixture, sources))
|
792
|
+
|
766
793
|
return mixtures
|
767
794
|
|
768
795
|
@cached_property
|
@@ -806,229 +833,340 @@ class MixtureDatabase:
|
|
806
833
|
with self.db() as c:
|
807
834
|
return int(c.execute("SELECT count(mixture.id) FROM mixture").fetchone()[0])
|
808
835
|
|
809
|
-
def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
|
836
|
+
def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
|
810
837
|
"""Read mixture data
|
811
838
|
|
812
839
|
:param m_id: Zero-based mixture ID
|
813
840
|
:param items: String(s) of dataset(s) to retrieve
|
814
|
-
:return:
|
841
|
+
:return: Dictionary of name: data
|
815
842
|
"""
|
816
|
-
from
|
843
|
+
from .data_io import read_cached_data
|
817
844
|
|
818
845
|
return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
|
819
846
|
|
820
|
-
def
|
821
|
-
"""Read
|
847
|
+
def read_source_audio(self, s_id: int) -> AudioT:
|
848
|
+
"""Read source audio
|
822
849
|
|
823
|
-
:param
|
824
|
-
:return:
|
850
|
+
:param s_id: Source ID
|
851
|
+
:return: Source audio
|
825
852
|
"""
|
826
853
|
from .audio import read_audio
|
827
854
|
|
828
|
-
return read_audio(self.
|
829
|
-
|
830
|
-
def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
|
831
|
-
"""Get augmented noise audio
|
832
|
-
|
833
|
-
:param mixture: Mixture
|
834
|
-
:return: Augmented noise audio
|
835
|
-
"""
|
836
|
-
from .audio import read_audio
|
837
|
-
from .augmentation import apply_augmentation
|
838
|
-
|
839
|
-
noise = self.noise_file(mixture.noise.file_id)
|
840
|
-
audio = read_audio(noise.name, self.use_cache)
|
841
|
-
audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
|
842
|
-
|
843
|
-
return audio
|
855
|
+
return read_audio(self.source_file(s_id).name, self.use_cache)
|
844
856
|
|
845
857
|
def mixture_class_indices(self, m_id: int) -> list[int]:
|
846
858
|
class_indices: list[int] = []
|
847
|
-
for
|
848
|
-
class_indices.extend(self.
|
859
|
+
for s_id in self.mixture(m_id).source_ids.values():
|
860
|
+
class_indices.extend(self.source_file(s_id).class_indices)
|
849
861
|
return sorted(set(class_indices))
|
850
862
|
|
851
|
-
def
|
852
|
-
"""Get the
|
863
|
+
def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
|
864
|
+
"""Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID
|
853
865
|
|
854
866
|
:param m_id: Zero-based mixture ID
|
855
867
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
856
|
-
:
|
868
|
+
:param cache: Cache result
|
869
|
+
:return: Dictionary of pre-truth source audio data (one per source in the mixture)
|
857
870
|
"""
|
858
|
-
from .
|
859
|
-
from .
|
860
|
-
from .
|
871
|
+
from .data_io import write_cached_data
|
872
|
+
from .effects import apply_effects
|
873
|
+
from .effects import conform_audio_to_length
|
861
874
|
|
862
875
|
if not force:
|
863
|
-
|
864
|
-
if
|
865
|
-
return
|
876
|
+
sources = self.read_mixture_data(m_id, "sources")["sources"]
|
877
|
+
if sources is not None:
|
878
|
+
return sources
|
866
879
|
|
867
880
|
mixture = self.mixture(m_id)
|
868
881
|
if mixture is None:
|
869
882
|
raise ValueError(f"Could not find mixture for m_id: {m_id}")
|
870
883
|
|
871
|
-
|
872
|
-
for
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
884
|
+
sources = {}
|
885
|
+
for category, source in mixture.all_sources.items():
|
886
|
+
source = mixture.all_sources[category]
|
887
|
+
audio = self.read_source_audio(source.file_id)
|
888
|
+
audio = apply_effects(self, audio, source.effects, pre=True, post=False)
|
889
|
+
audio = conform_audio_to_length(audio, mixture.samples, source.repeat, source.start)
|
890
|
+
sources[category] = audio
|
891
|
+
|
892
|
+
if cache:
|
893
|
+
write_cached_data(
|
894
|
+
location=self.location,
|
895
|
+
name="mixture",
|
896
|
+
index=mixture.name,
|
897
|
+
items={"sources": sources},
|
879
898
|
)
|
880
|
-
target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
|
881
|
-
target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
|
882
|
-
targets_audio.append(target_audio)
|
883
899
|
|
884
|
-
return
|
900
|
+
return sources
|
885
901
|
|
886
|
-
def
|
887
|
-
|
902
|
+
def mixture_sources_f(
|
903
|
+
self,
|
904
|
+
m_id: int,
|
905
|
+
sources: SourcesAudioT | None = None,
|
906
|
+
force: bool = False,
|
907
|
+
cache: bool = False,
|
908
|
+
) -> SourcesAudioF:
|
909
|
+
"""Get the pre-truth source transform data (one per source in the mixture) for the given mixture ID
|
888
910
|
|
889
911
|
:param m_id: Zero-based mixture ID
|
890
|
-
:param
|
912
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
891
913
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
892
|
-
:
|
914
|
+
:param cache: Cache result
|
915
|
+
:return: Dictionary of pre-truth source transform data (one per source in the mixture)
|
893
916
|
"""
|
917
|
+
from .data_io import write_cached_data
|
894
918
|
from .helpers import forward_transform
|
895
919
|
|
896
|
-
if
|
897
|
-
|
920
|
+
if sources is None:
|
921
|
+
sources = self.mixture_sources(m_id, force)
|
898
922
|
|
899
|
-
|
923
|
+
sources_f = {category: forward_transform(sources[category], self.ft_config) for category in sources}
|
924
|
+
|
925
|
+
if cache:
|
926
|
+
write_cached_data(
|
927
|
+
location=self.location,
|
928
|
+
name="mixture",
|
929
|
+
index=self.mixture(m_id).name,
|
930
|
+
items={"sources_f": sources_f},
|
931
|
+
)
|
900
932
|
|
901
|
-
|
902
|
-
|
933
|
+
return sources_f
|
934
|
+
|
935
|
+
def mixture_source(
|
936
|
+
self,
|
937
|
+
m_id: int,
|
938
|
+
sources: SourcesAudioT | None = None,
|
939
|
+
force: bool = False,
|
940
|
+
cache: bool = False,
|
941
|
+
) -> AudioT:
|
942
|
+
"""Get the post-truth, summed, and gained source audio data for the given mixture ID
|
903
943
|
|
904
944
|
:param m_id: Zero-based mixture ID
|
905
|
-
:param
|
945
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
906
946
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
907
|
-
:
|
947
|
+
:param cache: Cache result
|
948
|
+
:return: Post-truth, gained, and summed source audio data
|
908
949
|
"""
|
909
|
-
|
950
|
+
import numpy as np
|
951
|
+
|
952
|
+
from .data_io import write_cached_data
|
953
|
+
from .effects import apply_effects
|
910
954
|
|
911
955
|
if not force:
|
912
|
-
|
913
|
-
if
|
914
|
-
return
|
956
|
+
source = self.read_mixture_data(m_id, "source")["source"]
|
957
|
+
if source is not None:
|
958
|
+
return source
|
959
|
+
|
960
|
+
if sources is None:
|
961
|
+
sources = self.mixture_sources(m_id, force)
|
962
|
+
|
963
|
+
mixture = self.mixture(m_id)
|
915
964
|
|
916
|
-
|
917
|
-
|
965
|
+
source = np.sum(
|
966
|
+
[
|
967
|
+
apply_effects(
|
968
|
+
self,
|
969
|
+
audio=sources[category],
|
970
|
+
effects=mixture.all_sources[category].effects,
|
971
|
+
pre=False,
|
972
|
+
post=True,
|
973
|
+
)
|
974
|
+
* mixture.all_sources[category].snr_gain
|
975
|
+
for category in sources
|
976
|
+
if category != "noise"
|
977
|
+
],
|
978
|
+
axis=0,
|
979
|
+
)
|
918
980
|
|
919
|
-
|
981
|
+
if cache:
|
982
|
+
write_cached_data(
|
983
|
+
location=self.location,
|
984
|
+
name="mixture",
|
985
|
+
index=mixture.name,
|
986
|
+
items={"source": source},
|
987
|
+
)
|
920
988
|
|
921
|
-
|
989
|
+
return source
|
990
|
+
|
991
|
+
def mixture_source_f(
|
922
992
|
self,
|
923
993
|
m_id: int,
|
924
|
-
|
925
|
-
|
994
|
+
sources: SourcesAudioT | None = None,
|
995
|
+
source: AudioT | None = None,
|
926
996
|
force: bool = False,
|
997
|
+
cache: bool = False,
|
927
998
|
) -> AudioF:
|
928
|
-
"""Get the
|
999
|
+
"""Get the post-truth, summed, and gained source transform data for the given mixture ID
|
929
1000
|
|
930
1001
|
:param m_id: Zero-based mixture ID
|
931
|
-
:param
|
932
|
-
:param
|
1002
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1003
|
+
:param source: Post-truth, gained, and summed source audio for the given m_id
|
933
1004
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
934
|
-
:
|
1005
|
+
:param cache: Cache result
|
1006
|
+
:return: Post-truth, gained, and summed source transform data
|
935
1007
|
"""
|
1008
|
+
from .data_io import write_cached_data
|
936
1009
|
from .helpers import forward_transform
|
937
1010
|
|
938
|
-
if
|
939
|
-
|
1011
|
+
if source is None:
|
1012
|
+
source = self.mixture_source(m_id, sources, force)
|
1013
|
+
|
1014
|
+
source_f = forward_transform(source, self.ft_config)
|
1015
|
+
|
1016
|
+
if cache:
|
1017
|
+
write_cached_data(
|
1018
|
+
location=self.location,
|
1019
|
+
name="mixture",
|
1020
|
+
index=self.mixture(m_id).name,
|
1021
|
+
items={"source_f": source_f},
|
1022
|
+
)
|
940
1023
|
|
941
|
-
return
|
1024
|
+
return source_f
|
942
1025
|
|
943
|
-
def mixture_noise(
|
944
|
-
|
1026
|
+
def mixture_noise(
|
1027
|
+
self,
|
1028
|
+
m_id: int,
|
1029
|
+
sources: SourcesAudioT | None = None,
|
1030
|
+
force: bool = False,
|
1031
|
+
cache: bool = False,
|
1032
|
+
) -> AudioT:
|
1033
|
+
"""Get the post-truth and gained noise audio data for the given mixture ID
|
945
1034
|
|
946
1035
|
:param m_id: Zero-based mixture ID
|
1036
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
947
1037
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
948
|
-
:
|
1038
|
+
:param cache: Cache result
|
1039
|
+
:return: Post-truth and gained noise audio data
|
949
1040
|
"""
|
950
|
-
from .
|
951
|
-
from .
|
1041
|
+
from .data_io import write_cached_data
|
1042
|
+
from .effects import apply_effects
|
952
1043
|
|
953
1044
|
if not force:
|
954
|
-
noise = self.read_mixture_data(m_id, "noise")
|
1045
|
+
noise = self.read_mixture_data(m_id, "noise")["noise"]
|
955
1046
|
if noise is not None:
|
956
1047
|
return noise
|
957
1048
|
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
1049
|
+
if sources is None:
|
1050
|
+
sources = self.mixture_sources(m_id, force)
|
1051
|
+
|
1052
|
+
noise = self.mixture(m_id).noise
|
1053
|
+
noise = apply_effects(self, sources["noise"], noise.effects, pre=False, post=True) * noise.snr_gain
|
962
1054
|
|
963
|
-
|
964
|
-
|
1055
|
+
if cache:
|
1056
|
+
write_cached_data(
|
1057
|
+
location=self.location,
|
1058
|
+
name="mixture",
|
1059
|
+
index=self.mixture(m_id).name,
|
1060
|
+
items={"noise": noise},
|
1061
|
+
)
|
1062
|
+
|
1063
|
+
return noise
|
1064
|
+
|
1065
|
+
def mixture_noise_f(
|
1066
|
+
self,
|
1067
|
+
m_id: int,
|
1068
|
+
sources: SourcesAudioT | None = None,
|
1069
|
+
noise: AudioT | None = None,
|
1070
|
+
force: bool = False,
|
1071
|
+
cache: bool = False,
|
1072
|
+
) -> AudioF:
|
1073
|
+
"""Get the post-truth and gained noise transform for the given mixture ID
|
965
1074
|
|
966
1075
|
:param m_id: Zero-based mixture ID
|
967
|
-
:param
|
1076
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1077
|
+
:param noise: Post-truth and gained noise audio data
|
968
1078
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
969
|
-
:
|
1079
|
+
:param cache: Cache result
|
1080
|
+
:return: Post-truth and gained noise transform data
|
970
1081
|
"""
|
1082
|
+
from .data_io import write_cached_data
|
971
1083
|
from .helpers import forward_transform
|
972
1084
|
|
973
1085
|
if force or noise is None:
|
974
|
-
noise = self.mixture_noise(m_id, force)
|
1086
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1087
|
+
|
1088
|
+
noise_f = forward_transform(noise, self.ft_config)
|
1089
|
+
if cache:
|
1090
|
+
write_cached_data(
|
1091
|
+
location=self.location,
|
1092
|
+
name="mixture",
|
1093
|
+
index=self.mixture(m_id).name,
|
1094
|
+
items={"noise_f": noise_f},
|
1095
|
+
)
|
975
1096
|
|
976
|
-
return
|
1097
|
+
return noise_f
|
977
1098
|
|
978
1099
|
def mixture_mixture(
|
979
1100
|
self,
|
980
1101
|
m_id: int,
|
981
|
-
|
982
|
-
|
1102
|
+
sources: SourcesAudioT | None = None,
|
1103
|
+
source: AudioT | None = None,
|
983
1104
|
noise: AudioT | None = None,
|
984
1105
|
force: bool = False,
|
1106
|
+
cache: bool = False,
|
985
1107
|
) -> AudioT:
|
986
1108
|
"""Get the mixture audio data for the given mixture ID
|
987
1109
|
|
988
1110
|
:param m_id: Zero-based mixture ID
|
989
|
-
:param
|
990
|
-
:param
|
991
|
-
:param noise:
|
1111
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1112
|
+
:param source: Post-truth, gained, and summed source audio data
|
1113
|
+
:param noise: Post-truth and gained noise audio data
|
992
1114
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1115
|
+
:param cache: Cache result
|
993
1116
|
:return: Mixture audio data
|
994
1117
|
"""
|
1118
|
+
from .data_io import write_cached_data
|
1119
|
+
|
995
1120
|
if not force:
|
996
|
-
mixture = self.read_mixture_data(m_id, "mixture")
|
1121
|
+
mixture = self.read_mixture_data(m_id, "mixture")["mixture"]
|
997
1122
|
if mixture is not None:
|
998
1123
|
return mixture
|
999
1124
|
|
1000
|
-
if
|
1001
|
-
|
1125
|
+
if source is None:
|
1126
|
+
source = self.mixture_source(m_id, sources, force)
|
1002
1127
|
|
1003
|
-
if
|
1004
|
-
noise = self.mixture_noise(m_id, force)
|
1128
|
+
if noise is None:
|
1129
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1130
|
+
|
1131
|
+
mixture = source + noise
|
1132
|
+
|
1133
|
+
if cache:
|
1134
|
+
write_cached_data(
|
1135
|
+
location=self.location,
|
1136
|
+
name="mixture",
|
1137
|
+
index=self.mixture(m_id).name,
|
1138
|
+
items={"mixture": mixture},
|
1139
|
+
)
|
1005
1140
|
|
1006
|
-
return
|
1141
|
+
return mixture
|
1007
1142
|
|
1008
1143
|
def mixture_mixture_f(
|
1009
1144
|
self,
|
1010
1145
|
m_id: int,
|
1011
|
-
|
1012
|
-
|
1146
|
+
sources: SourcesAudioT | None = None,
|
1147
|
+
source: AudioT | None = None,
|
1013
1148
|
noise: AudioT | None = None,
|
1014
1149
|
mixture: AudioT | None = None,
|
1015
1150
|
force: bool = False,
|
1151
|
+
cache: bool = False,
|
1016
1152
|
) -> AudioF:
|
1017
1153
|
"""Get the mixture transform for the given mixture ID
|
1018
1154
|
|
1019
1155
|
:param m_id: Zero-based mixture ID
|
1020
|
-
:param
|
1021
|
-
:param
|
1022
|
-
:param noise:
|
1156
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1157
|
+
:param source: Post-truth, gained, and summed source audio data
|
1158
|
+
:param noise: Post-truth and gained noise audio data
|
1023
1159
|
:param mixture: Mixture audio data
|
1024
1160
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1161
|
+
:param cache: Cache result
|
1025
1162
|
:return: Mixture transform data
|
1026
1163
|
"""
|
1164
|
+
from .data_io import write_cached_data
|
1027
1165
|
from .helpers import forward_transform
|
1028
1166
|
from .spectral_mask import apply_spectral_mask
|
1029
1167
|
|
1030
|
-
if
|
1031
|
-
mixture = self.mixture_mixture(m_id,
|
1168
|
+
if mixture is None:
|
1169
|
+
mixture = self.mixture_mixture(m_id, sources, source, noise, force)
|
1032
1170
|
|
1033
1171
|
mixture_f = forward_transform(mixture, self.ft_config)
|
1034
1172
|
|
@@ -1040,80 +1178,79 @@ class MixtureDatabase:
|
|
1040
1178
|
seed=m.spectral_mask_seed,
|
1041
1179
|
)
|
1042
1180
|
|
1181
|
+
if cache:
|
1182
|
+
write_cached_data(
|
1183
|
+
location=self.location,
|
1184
|
+
name="mixture",
|
1185
|
+
index=self.mixture(m_id).name,
|
1186
|
+
items={"mixture_f": mixture_f},
|
1187
|
+
)
|
1188
|
+
|
1043
1189
|
return mixture_f
|
1044
1190
|
|
1045
|
-
def mixture_truth_t(
|
1046
|
-
self,
|
1047
|
-
m_id: int,
|
1048
|
-
targets: list[AudioT] | None = None,
|
1049
|
-
noise: AudioT | None = None,
|
1050
|
-
mixture: AudioT | None = None,
|
1051
|
-
force: bool = False,
|
1052
|
-
) -> list[TruthDict]:
|
1191
|
+
def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
|
1053
1192
|
"""Get the truth_t data for the given mixture ID
|
1054
1193
|
|
1055
1194
|
:param m_id: Zero-based mixture ID
|
1056
|
-
:param targets: List of augmented target audio data (one per target in the mixup) for the given mixture ID
|
1057
|
-
:param noise: Augmented noise audio data for the given mixture ID
|
1058
|
-
:param mixture: Mixture audio data for the given mixture ID
|
1059
1195
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1196
|
+
:param cache: Cache result
|
1060
1197
|
:return: list of truth_t data
|
1061
1198
|
"""
|
1199
|
+
from .data_io import write_cached_data
|
1062
1200
|
from .truth import truth_function
|
1063
1201
|
|
1064
1202
|
if not force:
|
1065
|
-
truth_t = self.read_mixture_data(m_id, "truth_t")
|
1203
|
+
truth_t = self.read_mixture_data(m_id, "truth_t")["truth_t"]
|
1066
1204
|
if truth_t is not None:
|
1067
1205
|
return truth_t
|
1068
1206
|
|
1069
|
-
|
1070
|
-
targets = self.mixture_targets(m_id, force)
|
1071
|
-
|
1072
|
-
if force or noise is None:
|
1073
|
-
noise = self.mixture_noise(m_id, force)
|
1074
|
-
|
1075
|
-
if force or mixture is None:
|
1076
|
-
mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
|
1207
|
+
truth_t = truth_function(self, m_id)
|
1077
1208
|
|
1078
|
-
if
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1209
|
+
if cache:
|
1210
|
+
write_cached_data(
|
1211
|
+
location=self.location,
|
1212
|
+
name="mixture",
|
1213
|
+
index=self.mixture(m_id).name,
|
1214
|
+
items={"truth_t": truth_t},
|
1215
|
+
)
|
1083
1216
|
|
1084
|
-
return
|
1217
|
+
return truth_t
|
1085
1218
|
|
1086
1219
|
def mixture_segsnr_t(
|
1087
1220
|
self,
|
1088
1221
|
m_id: int,
|
1089
|
-
|
1090
|
-
|
1222
|
+
sources: SourcesAudioT | None = None,
|
1223
|
+
source: AudioT | None = None,
|
1091
1224
|
noise: AudioT | None = None,
|
1092
1225
|
force: bool = False,
|
1226
|
+
cache: bool = False,
|
1093
1227
|
) -> Segsnr:
|
1094
1228
|
"""Get the segsnr_t data for the given mixture ID
|
1095
1229
|
|
1096
1230
|
:param m_id: Zero-based mixture ID
|
1097
|
-
:param
|
1098
|
-
:param
|
1099
|
-
:param noise:
|
1231
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1232
|
+
:param source: Post-truth, gained, and summed source audio data
|
1233
|
+
:param noise: Post-truth and gained noise audio data
|
1100
1234
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1235
|
+
:param cache: Cache result
|
1101
1236
|
:return: segsnr_t data
|
1102
1237
|
"""
|
1103
1238
|
import numpy as np
|
1104
1239
|
import torch
|
1105
1240
|
from pyaaware import ForwardTransform
|
1106
1241
|
|
1242
|
+
from .data_io import write_cached_data
|
1243
|
+
|
1107
1244
|
if not force:
|
1108
|
-
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
1245
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
|
1109
1246
|
if segsnr_t is not None:
|
1110
1247
|
return segsnr_t
|
1111
1248
|
|
1112
|
-
if
|
1113
|
-
|
1249
|
+
if source is None:
|
1250
|
+
source = self.mixture_source(m_id, sources, force)
|
1114
1251
|
|
1115
|
-
if
|
1116
|
-
noise = self.mixture_noise(m_id, force)
|
1252
|
+
if noise is None:
|
1253
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1117
1254
|
|
1118
1255
|
ft = ForwardTransform(
|
1119
1256
|
length=self.ft_config.length,
|
@@ -1127,13 +1264,13 @@ class MixtureDatabase:
|
|
1127
1264
|
|
1128
1265
|
segsnr_t = np.empty(mixture.samples, dtype=np.float32)
|
1129
1266
|
|
1130
|
-
|
1267
|
+
source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
|
1131
1268
|
noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
|
1132
1269
|
|
1133
1270
|
offsets = range(0, mixture.samples, self.ft_config.overlap)
|
1134
|
-
if len(
|
1271
|
+
if len(source_energy) != len(offsets):
|
1135
1272
|
raise ValueError(
|
1136
|
-
f"Number of frames in energy, {len(
|
1273
|
+
f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(offsets)}"
|
1137
1274
|
)
|
1138
1275
|
|
1139
1276
|
for idx, offset in enumerate(offsets):
|
@@ -1142,187 +1279,242 @@ class MixtureDatabase:
|
|
1142
1279
|
if noise_energy[idx] == 0:
|
1143
1280
|
snr = np.float32(np.inf)
|
1144
1281
|
else:
|
1145
|
-
snr = np.float32(
|
1282
|
+
snr = np.float32(source_energy[idx] / noise_energy[idx])
|
1146
1283
|
|
1147
1284
|
segsnr_t[indices] = snr
|
1148
1285
|
|
1286
|
+
if cache:
|
1287
|
+
write_cached_data(
|
1288
|
+
location=self.location,
|
1289
|
+
name="mixture",
|
1290
|
+
index=mixture.name,
|
1291
|
+
items={"segsnr_t": segsnr_t},
|
1292
|
+
)
|
1293
|
+
|
1149
1294
|
return segsnr_t
|
1150
1295
|
|
1151
1296
|
def mixture_segsnr(
|
1152
1297
|
self,
|
1153
1298
|
m_id: int,
|
1154
1299
|
segsnr_t: Segsnr | None = None,
|
1155
|
-
|
1156
|
-
|
1300
|
+
sources: SourcesAudioT | None = None,
|
1301
|
+
source: AudioT | None = None,
|
1157
1302
|
noise: AudioT | None = None,
|
1158
1303
|
force: bool = False,
|
1304
|
+
cache: bool = False,
|
1159
1305
|
) -> Segsnr:
|
1160
1306
|
"""Get the segsnr data for the given mixture ID
|
1161
1307
|
|
1162
1308
|
:param m_id: Zero-based mixture ID
|
1163
1309
|
:param segsnr_t: segsnr_t data
|
1164
|
-
:param
|
1165
|
-
:param
|
1166
|
-
:param noise:
|
1310
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1311
|
+
:param source: Post-truth, gained, and summed source audio data
|
1312
|
+
:param noise: Post-truth and gained noise audio data
|
1167
1313
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1314
|
+
:param cache: Cache result
|
1168
1315
|
:return: segsnr data
|
1169
1316
|
"""
|
1317
|
+
from .data_io import write_cached_data
|
1318
|
+
|
1170
1319
|
if not force:
|
1171
|
-
segsnr = self.read_mixture_data(m_id, "segsnr")
|
1320
|
+
segsnr = self.read_mixture_data(m_id, "segsnr")["segsnr"]
|
1172
1321
|
if segsnr is not None:
|
1173
1322
|
return segsnr
|
1174
1323
|
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1324
|
+
if segsnr_t is None:
|
1325
|
+
segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)
|
1326
|
+
|
1327
|
+
segsnr = segsnr_t[0 :: self.ft_config.overlap]
|
1178
1328
|
|
1179
|
-
if
|
1180
|
-
|
1329
|
+
if cache:
|
1330
|
+
write_cached_data(
|
1331
|
+
location=self.location,
|
1332
|
+
name="mixture",
|
1333
|
+
index=self.mixture(m_id).name,
|
1334
|
+
items={"segsnr": segsnr},
|
1335
|
+
)
|
1181
1336
|
|
1182
|
-
return
|
1337
|
+
return segsnr
|
1183
1338
|
|
1184
1339
|
def mixture_ft(
|
1185
1340
|
self,
|
1186
1341
|
m_id: int,
|
1187
|
-
|
1188
|
-
|
1342
|
+
sources: SourcesAudioT | None = None,
|
1343
|
+
source: AudioT | None = None,
|
1189
1344
|
noise: AudioT | None = None,
|
1190
1345
|
mixture_f: AudioF | None = None,
|
1191
1346
|
mixture: AudioT | None = None,
|
1192
|
-
truth_t:
|
1347
|
+
truth_t: TruthsDict | None = None,
|
1193
1348
|
force: bool = False,
|
1194
|
-
|
1349
|
+
cache: bool = False,
|
1350
|
+
) -> tuple[Feature, TruthsDict]:
|
1195
1351
|
"""Get the feature and truth_f data for the given mixture ID
|
1196
1352
|
|
1197
1353
|
:param m_id: Zero-based mixture ID
|
1198
|
-
:param
|
1199
|
-
:param
|
1200
|
-
:param noise:
|
1354
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1355
|
+
:param source: Post-truth, gained, and summed source audio data
|
1356
|
+
:param noise: Post-truth and gained noise audio data
|
1201
1357
|
:param mixture_f: Mixture transform data
|
1202
1358
|
:param mixture: Mixture audio data
|
1203
1359
|
:param truth_t: truth_t
|
1204
1360
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1361
|
+
:param cache: Cache result
|
1205
1362
|
:return: Tuple of (feature, truth_f) data
|
1206
1363
|
"""
|
1207
1364
|
from pyaaware import FeatureGenerator
|
1208
1365
|
|
1366
|
+
from .data_io import write_cached_data
|
1209
1367
|
from .truth import truth_stride_reduction
|
1210
1368
|
|
1211
1369
|
if not force:
|
1212
|
-
|
1213
|
-
if feature is not None and truth_f is not None:
|
1214
|
-
return feature, truth_f
|
1370
|
+
ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
|
1371
|
+
if ft["feature"] is not None and ft["truth_f"] is not None:
|
1372
|
+
return ft["feature"], ft["truth_f"]
|
1215
1373
|
|
1216
|
-
if
|
1374
|
+
if mixture_f is None:
|
1217
1375
|
mixture_f = self.mixture_mixture_f(
|
1218
1376
|
m_id=m_id,
|
1219
|
-
|
1220
|
-
|
1377
|
+
sources=sources,
|
1378
|
+
source=source,
|
1221
1379
|
noise=noise,
|
1222
1380
|
mixture=mixture,
|
1223
1381
|
force=force,
|
1224
1382
|
)
|
1225
1383
|
|
1226
|
-
if
|
1227
|
-
truth_t = self.mixture_truth_t(m_id
|
1384
|
+
if truth_t is None:
|
1385
|
+
truth_t = self.mixture_truth_t(m_id, force)
|
1228
1386
|
|
1229
1387
|
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
1230
1388
|
|
1231
|
-
|
1232
|
-
feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
|
1389
|
+
feature, truth_f = fg.execute_all(mixture_f, truth_t)
|
1233
1390
|
if truth_f is not None:
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1391
|
+
truth_configs = self.mixture_truth_configs(m_id)
|
1392
|
+
for category, configs in truth_configs.items():
|
1393
|
+
for name, config in configs.items():
|
1394
|
+
if self.truth_parameters[category][name] is not None:
|
1395
|
+
truth_f[category][name] = truth_stride_reduction(
|
1396
|
+
truth_f[category][name], config.stride_reduction
|
1397
|
+
)
|
1237
1398
|
else:
|
1238
1399
|
raise TypeError("Unexpected truth of None from feature generator")
|
1239
1400
|
|
1401
|
+
if cache:
|
1402
|
+
write_cached_data(
|
1403
|
+
location=self.location,
|
1404
|
+
name="mixture",
|
1405
|
+
index=self.mixture(m_id).name,
|
1406
|
+
items={"feature": truth_f, "truth_f": truth_f},
|
1407
|
+
)
|
1408
|
+
|
1240
1409
|
return feature, truth_f
|
1241
1410
|
|
1242
1411
|
def mixture_feature(
|
1243
1412
|
self,
|
1244
1413
|
m_id: int,
|
1245
|
-
|
1414
|
+
sources: SourcesAudioT | None = None,
|
1246
1415
|
noise: AudioT | None = None,
|
1247
1416
|
mixture: AudioT | None = None,
|
1248
|
-
truth_t:
|
1417
|
+
truth_t: TruthsDict | None = None,
|
1249
1418
|
force: bool = False,
|
1419
|
+
cache: bool = False,
|
1250
1420
|
) -> Feature:
|
1251
1421
|
"""Get the feature data for the given mixture ID
|
1252
1422
|
|
1253
1423
|
:param m_id: Zero-based mixture ID
|
1254
|
-
:param
|
1255
|
-
:param noise:
|
1424
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1425
|
+
:param noise: Post-truth and gained noise audio data
|
1256
1426
|
:param mixture: Mixture audio data
|
1257
1427
|
:param truth_t: truth_t
|
1258
1428
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1429
|
+
:param cache: Cache result
|
1259
1430
|
:return: Feature data
|
1260
1431
|
"""
|
1261
|
-
|
1432
|
+
from .data_io import write_cached_data
|
1433
|
+
|
1434
|
+
feature = self.mixture_ft(
|
1262
1435
|
m_id=m_id,
|
1263
|
-
|
1436
|
+
sources=sources,
|
1264
1437
|
noise=noise,
|
1265
1438
|
mixture=mixture,
|
1266
1439
|
truth_t=truth_t,
|
1267
1440
|
force=force,
|
1268
|
-
)
|
1441
|
+
)[0]
|
1442
|
+
|
1443
|
+
if cache:
|
1444
|
+
write_cached_data(
|
1445
|
+
location=self.location,
|
1446
|
+
name="mixture",
|
1447
|
+
index=self.mixture(m_id).name,
|
1448
|
+
items={"feature": feature},
|
1449
|
+
)
|
1450
|
+
|
1269
1451
|
return feature
|
1270
1452
|
|
1271
1453
|
def mixture_truth_f(
|
1272
1454
|
self,
|
1273
1455
|
m_id: int,
|
1274
|
-
|
1456
|
+
sources: SourcesAudioT | None = None,
|
1275
1457
|
noise: AudioT | None = None,
|
1276
1458
|
mixture: AudioT | None = None,
|
1277
|
-
truth_t:
|
1459
|
+
truth_t: TruthsDict | None = None,
|
1278
1460
|
force: bool = False,
|
1461
|
+
cache: bool = False,
|
1279
1462
|
) -> TruthDict:
|
1280
1463
|
"""Get the truth_f data for the given mixture ID
|
1281
1464
|
|
1282
1465
|
:param m_id: Zero-based mixture ID
|
1283
|
-
:param
|
1284
|
-
:param noise:
|
1466
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1467
|
+
:param noise: Post-truth and gained noise audio data
|
1285
1468
|
:param mixture: Mixture audio data
|
1286
1469
|
:param truth_t: truth_t
|
1287
1470
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1471
|
+
:param cache: Cache result
|
1288
1472
|
:return: truth_f data
|
1289
1473
|
"""
|
1290
|
-
|
1474
|
+
from .data_io import write_cached_data
|
1475
|
+
|
1476
|
+
truth_f = self.mixture_ft(
|
1291
1477
|
m_id=m_id,
|
1292
|
-
|
1478
|
+
sources=sources,
|
1293
1479
|
noise=noise,
|
1294
1480
|
mixture=mixture,
|
1295
1481
|
truth_t=truth_t,
|
1296
1482
|
force=force,
|
1297
|
-
)
|
1483
|
+
)[1]
|
1484
|
+
|
1485
|
+
if cache:
|
1486
|
+
write_cached_data(
|
1487
|
+
location=self.location,
|
1488
|
+
name="mixture",
|
1489
|
+
index=self.mixture(m_id).name,
|
1490
|
+
items={"truth_f": truth_f},
|
1491
|
+
)
|
1492
|
+
|
1298
1493
|
return truth_f
|
1299
1494
|
|
1300
|
-
def mixture_class_count(
|
1301
|
-
self,
|
1302
|
-
m_id: int,
|
1303
|
-
targets: list[AudioT] | None = None,
|
1304
|
-
noise: AudioT | None = None,
|
1305
|
-
truth_t: list[TruthDict] | None = None,
|
1306
|
-
) -> ClassCount:
|
1495
|
+
def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
|
1307
1496
|
"""Compute the number of frames for which each class index is active for the given mixture ID
|
1308
1497
|
|
1309
1498
|
:param m_id: Zero-based mixture ID
|
1310
|
-
:param targets: List of augmented target audio (one per target in the mixup)
|
1311
|
-
:param noise: Augmented noise audio
|
1312
1499
|
:param truth_t: truth_t
|
1313
|
-
:return:
|
1500
|
+
:return: Dictionary of class counts
|
1314
1501
|
"""
|
1315
1502
|
import numpy as np
|
1316
1503
|
|
1317
1504
|
if truth_t is None:
|
1318
|
-
truth_t = self.mixture_truth_t(m_id
|
1505
|
+
truth_t = self.mixture_truth_t(m_id)
|
1319
1506
|
|
1320
|
-
class_count
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1507
|
+
class_count: dict[str, ClassCount] = {}
|
1508
|
+
|
1509
|
+
truth_configs = self.mixture_truth_configs(m_id)
|
1510
|
+
for category in truth_configs:
|
1511
|
+
class_count[category] = [0] * self.num_classes
|
1512
|
+
for configs in truth_configs[category]:
|
1513
|
+
if "sed" in configs:
|
1514
|
+
for cl in range(self.num_classes):
|
1515
|
+
class_count[category][cl] = int(
|
1516
|
+
np.sum(truth_t[category]["sed"][:, cl] >= self.class_weights_thresholds[cl])
|
1517
|
+
)
|
1326
1518
|
|
1327
1519
|
return class_count
|
1328
1520
|
|
@@ -1348,57 +1540,56 @@ class MixtureDatabase:
|
|
1348
1540
|
return _speaker(self.db, s_id, tier, self.use_cache)
|
1349
1541
|
|
1350
1542
|
def speech_metadata(self, tier: str) -> list[str]:
|
1351
|
-
from .helpers import
|
1543
|
+
from .helpers import get_textgrid_tier_from_source_file
|
1352
1544
|
|
1353
1545
|
results: set[str] = set()
|
1354
1546
|
if tier in self.textgrid_metadata_tiers:
|
1355
|
-
for
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1547
|
+
for source_files in self.source_files.values():
|
1548
|
+
for source_file in source_files:
|
1549
|
+
data = get_textgrid_tier_from_source_file(source_file.name, tier)
|
1550
|
+
if data is None:
|
1551
|
+
continue
|
1552
|
+
if isinstance(data, list):
|
1553
|
+
for item in data:
|
1554
|
+
results.add(item.label)
|
1555
|
+
else:
|
1556
|
+
results.add(data)
|
1364
1557
|
elif tier in self.speaker_metadata_tiers:
|
1365
|
-
for
|
1366
|
-
|
1367
|
-
|
1368
|
-
|
1558
|
+
for source_files in self.source_files.values():
|
1559
|
+
for source_file in source_files:
|
1560
|
+
data = self.speaker(source_file.speaker_id, tier)
|
1561
|
+
if data is not None:
|
1562
|
+
results.add(data)
|
1369
1563
|
|
1370
1564
|
return sorted(results)
|
1371
1565
|
|
1372
|
-
def mixture_speech_metadata(self, mixid: int, tier: str) ->
|
1566
|
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
|
1373
1567
|
from praatio.utilities.constants import Interval
|
1374
1568
|
|
1375
|
-
from .helpers import
|
1569
|
+
from .helpers import get_textgrid_tier_from_source_file
|
1376
1570
|
|
1377
|
-
results:
|
1571
|
+
results: dict[str, SpeechMetadata] = {}
|
1378
1572
|
is_textgrid = tier in self.textgrid_metadata_tiers
|
1379
1573
|
if is_textgrid:
|
1380
|
-
for
|
1381
|
-
data =
|
1574
|
+
for category, source in self.mixture(mixid).all_sources.items():
|
1575
|
+
data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
|
1382
1576
|
if isinstance(data, list):
|
1383
|
-
# Check for tempo
|
1577
|
+
# Check for tempo effect and adjust Interval start and end data as needed
|
1384
1578
|
entries = []
|
1385
1579
|
for entry in data:
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
entry.label,
|
1392
|
-
)
|
1580
|
+
entries.append(
|
1581
|
+
Interval(
|
1582
|
+
entry.start / source.pre_tempo,
|
1583
|
+
entry.end / source.pre_tempo,
|
1584
|
+
entry.label,
|
1393
1585
|
)
|
1394
|
-
|
1395
|
-
|
1396
|
-
results.append(entries)
|
1586
|
+
)
|
1587
|
+
results[category] = entries
|
1397
1588
|
else:
|
1398
|
-
results
|
1589
|
+
results[category] = data
|
1399
1590
|
else:
|
1400
|
-
for
|
1401
|
-
results
|
1591
|
+
for category, source in self.mixture(mixid).all_sources.items():
|
1592
|
+
results[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
|
1402
1593
|
|
1403
1594
|
return results
|
1404
1595
|
|
@@ -1450,7 +1641,7 @@ class MixtureDatabase:
|
|
1450
1641
|
|
1451
1642
|
return [mixture_id[0] - 1 for mixture_id in results]
|
1452
1643
|
|
1453
|
-
def mixture_all_speech_metadata(self, m_id: int) ->
|
1644
|
+
def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
|
1454
1645
|
from .helpers import mixture_all_speech_metadata
|
1455
1646
|
|
1456
1647
|
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
@@ -1483,63 +1674,65 @@ class MixtureDatabase:
|
|
1483
1674
|
:param m_id: Zero-based mixture ID
|
1484
1675
|
:param metrics: List of metrics to get
|
1485
1676
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1486
|
-
:return:
|
1677
|
+
:return: Dictionary of metric data
|
1487
1678
|
"""
|
1488
1679
|
from collections.abc import Callable
|
1489
1680
|
|
1490
1681
|
import numpy as np
|
1491
1682
|
from pystoi import stoi
|
1492
1683
|
|
1493
|
-
from
|
1494
|
-
from
|
1495
|
-
from
|
1496
|
-
from
|
1497
|
-
from
|
1498
|
-
from
|
1499
|
-
from
|
1500
|
-
from
|
1501
|
-
from
|
1502
|
-
from
|
1503
|
-
from
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1684
|
+
from ..constants import SAMPLE_RATE
|
1685
|
+
from ..datatypes import AudioStatsMetrics
|
1686
|
+
from ..datatypes import SpeechMetrics
|
1687
|
+
from ..metrics.calc_audio_stats import calc_audio_stats
|
1688
|
+
from ..metrics.calc_pesq import calc_pesq
|
1689
|
+
from ..metrics.calc_phase_distance import calc_phase_distance
|
1690
|
+
from ..metrics.calc_segsnr_f import calc_segsnr_f
|
1691
|
+
from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
|
1692
|
+
from ..metrics.calc_speech import calc_speech
|
1693
|
+
from ..metrics.calc_wer import calc_wer
|
1694
|
+
from ..metrics.calc_wsdr import calc_wsdr
|
1695
|
+
from ..utils.asr import calc_asr
|
1696
|
+
from ..utils.db import linear_to_db
|
1697
|
+
|
1698
|
+
def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
|
1699
|
+
state: dict[str, AudioT] | None = None
|
1700
|
+
|
1701
|
+
def get() -> dict[str, AudioT]:
|
1509
1702
|
nonlocal state
|
1510
1703
|
if state is None:
|
1511
|
-
state = self.
|
1704
|
+
state = self.mixture_sources(m_id)
|
1512
1705
|
return state
|
1513
1706
|
|
1514
1707
|
return get
|
1515
1708
|
|
1516
|
-
|
1709
|
+
sources_audio = create_sources_audio()
|
1517
1710
|
|
1518
|
-
def
|
1711
|
+
def create_source_audio() -> Callable[[], AudioT]:
|
1519
1712
|
state: AudioT | None = None
|
1520
1713
|
|
1521
1714
|
def get() -> AudioT:
|
1522
1715
|
nonlocal state
|
1523
1716
|
if state is None:
|
1524
|
-
state = self.
|
1717
|
+
state = self.mixture_source(m_id)
|
1525
1718
|
return state
|
1526
1719
|
|
1527
1720
|
return get
|
1528
1721
|
|
1529
|
-
|
1722
|
+
source_audio = create_source_audio()
|
1530
1723
|
|
1531
|
-
def
|
1724
|
+
def create_source_f() -> Callable[[], AudioF]:
|
1532
1725
|
state: AudioF | None = None
|
1533
1726
|
|
1534
1727
|
def get() -> AudioF:
|
1535
1728
|
nonlocal state
|
1536
1729
|
if state is None:
|
1537
|
-
state = self.
|
1730
|
+
state = self.mixture_source_f(m_id)
|
1538
1731
|
return state
|
1539
1732
|
|
1540
1733
|
return get
|
1541
1734
|
|
1542
|
-
|
1735
|
+
source_f = create_source_f()
|
1543
1736
|
|
1544
1737
|
def create_noise_audio() -> Callable[[], AudioT]:
|
1545
1738
|
state: AudioT | None = None
|
@@ -1593,15 +1786,29 @@ class MixtureDatabase:
|
|
1593
1786
|
|
1594
1787
|
segsnr_f = create_segsnr_f()
|
1595
1788
|
|
1596
|
-
def
|
1597
|
-
state:
|
1789
|
+
def create_pesq() -> Callable[[], dict[str, float]]:
|
1790
|
+
state: dict[str, float] | None = None
|
1791
|
+
|
1792
|
+
def get() -> dict[str, float]:
|
1793
|
+
nonlocal state
|
1794
|
+
if state is None:
|
1795
|
+
state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
|
1796
|
+
return state
|
1797
|
+
|
1798
|
+
return get
|
1799
|
+
|
1800
|
+
pesq = create_pesq()
|
1801
|
+
|
1802
|
+
def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
|
1803
|
+
state: dict[str, SpeechMetrics] | None = None
|
1598
1804
|
|
1599
|
-
def get() ->
|
1805
|
+
def get() -> dict[str, SpeechMetrics]:
|
1600
1806
|
nonlocal state
|
1601
1807
|
if state is None:
|
1602
|
-
state =
|
1603
|
-
|
1604
|
-
|
1808
|
+
state = {
|
1809
|
+
category: calc_speech(mixture_audio(), audio, pesq()[category])
|
1810
|
+
for category, audio in sources_audio().items()
|
1811
|
+
}
|
1605
1812
|
return state
|
1606
1813
|
|
1607
1814
|
return get
|
@@ -1621,33 +1828,34 @@ class MixtureDatabase:
|
|
1621
1828
|
|
1622
1829
|
mixture_stats = create_mixture_stats()
|
1623
1830
|
|
1624
|
-
def
|
1625
|
-
state:
|
1831
|
+
def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
|
1832
|
+
state: dict[str, AudioStatsMetrics] | None = None
|
1626
1833
|
|
1627
|
-
def get() ->
|
1834
|
+
def get() -> dict[str, AudioStatsMetrics]:
|
1628
1835
|
nonlocal state
|
1629
1836
|
if state is None:
|
1630
|
-
state =
|
1631
|
-
|
1632
|
-
|
1837
|
+
state = {
|
1838
|
+
category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
|
1839
|
+
for category, audio in sources_audio().items()
|
1840
|
+
}
|
1633
1841
|
return state
|
1634
1842
|
|
1635
1843
|
return get
|
1636
1844
|
|
1637
|
-
|
1845
|
+
sources_stats = create_sources_stats()
|
1638
1846
|
|
1639
|
-
def
|
1847
|
+
def create_source_stats() -> Callable[[], AudioStatsMetrics]:
|
1640
1848
|
state: AudioStatsMetrics | None = None
|
1641
1849
|
|
1642
1850
|
def get() -> AudioStatsMetrics:
|
1643
1851
|
nonlocal state
|
1644
1852
|
if state is None:
|
1645
|
-
state = calc_audio_stats(
|
1853
|
+
state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1646
1854
|
return state
|
1647
1855
|
|
1648
1856
|
return get
|
1649
1857
|
|
1650
|
-
|
1858
|
+
source_stats = create_source_stats()
|
1651
1859
|
|
1652
1860
|
def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
|
1653
1861
|
state: AudioStatsMetrics | None = None
|
@@ -1678,33 +1886,34 @@ class MixtureDatabase:
|
|
1678
1886
|
|
1679
1887
|
asr_config = create_asr_config()
|
1680
1888
|
|
1681
|
-
def
|
1682
|
-
state: dict[str,
|
1889
|
+
def create_sources_asr() -> Callable[[str], dict[str, str]]:
|
1890
|
+
state: dict[str, dict[str, str]] = {}
|
1683
1891
|
|
1684
|
-
def get(asr_name) ->
|
1892
|
+
def get(asr_name) -> dict[str, str]:
|
1685
1893
|
nonlocal state
|
1686
1894
|
if asr_name not in state:
|
1687
|
-
state[asr_name] =
|
1688
|
-
|
1689
|
-
|
1895
|
+
state[asr_name] = {
|
1896
|
+
category: calc_asr(audio, **asr_config(asr_name)).text
|
1897
|
+
for category, audio in sources_audio().items()
|
1898
|
+
}
|
1690
1899
|
return state[asr_name]
|
1691
1900
|
|
1692
1901
|
return get
|
1693
1902
|
|
1694
|
-
|
1903
|
+
sources_asr = create_sources_asr()
|
1695
1904
|
|
1696
|
-
def
|
1905
|
+
def create_source_asr() -> Callable[[str], str]:
|
1697
1906
|
state: dict[str, str] = {}
|
1698
1907
|
|
1699
1908
|
def get(asr_name) -> str:
|
1700
1909
|
nonlocal state
|
1701
1910
|
if asr_name not in state:
|
1702
|
-
state[asr_name] = calc_asr(
|
1911
|
+
state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
|
1703
1912
|
return state[asr_name]
|
1704
1913
|
|
1705
1914
|
return get
|
1706
1915
|
|
1707
|
-
|
1916
|
+
source_asr = create_source_asr()
|
1708
1917
|
|
1709
1918
|
def create_mixture_asr() -> Callable[[str], str]:
|
1710
1919
|
state: dict[str, str] = {}
|
@@ -1728,11 +1937,11 @@ class MixtureDatabase:
|
|
1728
1937
|
|
1729
1938
|
def calc(m: str) -> Any:
|
1730
1939
|
if m == "mxsnr":
|
1731
|
-
return self.mixture(m_id).
|
1940
|
+
return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
|
1732
1941
|
|
1733
1942
|
# Get cached data first, if exists
|
1734
1943
|
if not force:
|
1735
|
-
value = self.read_mixture_data(m_id, m)
|
1944
|
+
value = self.read_mixture_data(m_id, m)[m]
|
1736
1945
|
if value is not None:
|
1737
1946
|
return value
|
1738
1947
|
|
@@ -1744,8 +1953,8 @@ class MixtureDatabase:
|
|
1744
1953
|
# noise only, ignore/reset target asr
|
1745
1954
|
return float("nan")
|
1746
1955
|
|
1747
|
-
if
|
1748
|
-
return calc_wer(mixture_asr(asr_name),
|
1956
|
+
if source_asr(asr_name):
|
1957
|
+
return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
|
1749
1958
|
|
1750
1959
|
# TODO: should this be NaN like above?
|
1751
1960
|
return float(0)
|
@@ -1753,12 +1962,14 @@ class MixtureDatabase:
|
|
1753
1962
|
if m.startswith("basewer"):
|
1754
1963
|
asr_name = get_asr_name(m)
|
1755
1964
|
|
1756
|
-
text = self.mixture_speech_metadata(m_id, "text")
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1965
|
+
text = self.mixture_speech_metadata(m_id, "text")
|
1966
|
+
base_wer: dict[str, float] = {}
|
1967
|
+
for category, source in sources_asr(asr_name).items():
|
1968
|
+
if isinstance(text[category], str):
|
1969
|
+
base_wer[category] = calc_wer(source, str(text[category])).wer * 100
|
1970
|
+
else:
|
1971
|
+
base_wer[category] = 0
|
1972
|
+
return base_wer
|
1762
1973
|
|
1763
1974
|
if m.startswith("mxasr"):
|
1764
1975
|
return mixture_asr(get_asr_name(m))
|
@@ -1769,6 +1980,18 @@ class MixtureDatabase:
|
|
1769
1980
|
if m == "mxssnr_std":
|
1770
1981
|
return calc_segsnr_f(segsnr_f()).std
|
1771
1982
|
|
1983
|
+
if m == "mxssnr_avg_db":
|
1984
|
+
val = calc_segsnr_f(segsnr_f()).avg
|
1985
|
+
if val is not None:
|
1986
|
+
return linear_to_db(val)
|
1987
|
+
return None
|
1988
|
+
|
1989
|
+
if m == "mxssnr_std_db":
|
1990
|
+
val = calc_segsnr_f(segsnr_f()).std
|
1991
|
+
if val is not None:
|
1992
|
+
return linear_to_db(val)
|
1993
|
+
return None
|
1994
|
+
|
1772
1995
|
if m == "mxssnrdb_avg":
|
1773
1996
|
return calc_segsnr_f(segsnr_f()).db_avg
|
1774
1997
|
|
@@ -1776,40 +1999,40 @@ class MixtureDatabase:
|
|
1776
1999
|
return calc_segsnr_f(segsnr_f()).db_std
|
1777
2000
|
|
1778
2001
|
if m == "mxssnrf_avg":
|
1779
|
-
return calc_segsnr_f_bin(
|
2002
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).avg
|
1780
2003
|
|
1781
2004
|
if m == "mxssnrf_std":
|
1782
|
-
return calc_segsnr_f_bin(
|
2005
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).std
|
1783
2006
|
|
1784
2007
|
if m == "mxssnrdbf_avg":
|
1785
|
-
return calc_segsnr_f_bin(
|
2008
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
|
1786
2009
|
|
1787
2010
|
if m == "mxssnrdbf_std":
|
1788
|
-
return calc_segsnr_f_bin(
|
2011
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).db_std
|
1789
2012
|
|
1790
2013
|
if m == "mxpesq":
|
1791
2014
|
if self.mixture(m_id).is_noise_only:
|
1792
|
-
return
|
1793
|
-
return
|
2015
|
+
return dict.fromkeys(pesq(), 0)
|
2016
|
+
return pesq()
|
1794
2017
|
|
1795
2018
|
if m == "mxcsig":
|
1796
2019
|
if self.mixture(m_id).is_noise_only:
|
1797
|
-
return
|
1798
|
-
return
|
2020
|
+
return dict.fromkeys(speech(), 0)
|
2021
|
+
return {category: s.csig for category, s in speech().items()}
|
1799
2022
|
|
1800
2023
|
if m == "mxcbak":
|
1801
2024
|
if self.mixture(m_id).is_noise_only:
|
1802
|
-
return
|
1803
|
-
return
|
2025
|
+
return dict.fromkeys(speech(), 0)
|
2026
|
+
return {category: s.cbak for category, s in speech().items()}
|
1804
2027
|
|
1805
2028
|
if m == "mxcovl":
|
1806
2029
|
if self.mixture(m_id).is_noise_only:
|
1807
|
-
return
|
1808
|
-
return
|
2030
|
+
return dict.fromkeys(speech(), 0)
|
2031
|
+
return {category: s.covl for category, s in speech().items()}
|
1809
2032
|
|
1810
2033
|
if m == "mxwsdr":
|
1811
2034
|
mixture = mixture_audio()[:, np.newaxis]
|
1812
|
-
target =
|
2035
|
+
target = source_audio()[:, np.newaxis]
|
1813
2036
|
noise = noise_audio()[:, np.newaxis]
|
1814
2037
|
return calc_wsdr(
|
1815
2038
|
hypothesis=np.concatenate((mixture, noise), axis=1),
|
@@ -1819,11 +2042,11 @@ class MixtureDatabase:
|
|
1819
2042
|
|
1820
2043
|
if m == "mxpd":
|
1821
2044
|
mixture_f = self.mixture_mixture_f(m_id)
|
1822
|
-
return calc_phase_distance(hypothesis=mixture_f, reference=
|
2045
|
+
return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
|
1823
2046
|
|
1824
2047
|
if m == "mxstoi":
|
1825
2048
|
return stoi(
|
1826
|
-
x=
|
2049
|
+
x=source_audio(),
|
1827
2050
|
y=mixture_audio(),
|
1828
2051
|
fs_sig=SAMPLE_RATE,
|
1829
2052
|
extended=False,
|
@@ -1860,70 +2083,70 @@ class MixtureDatabase:
|
|
1860
2083
|
return mixture_stats().pkc
|
1861
2084
|
|
1862
2085
|
if m == "mxtdco":
|
1863
|
-
return
|
2086
|
+
return source_stats().dco
|
1864
2087
|
|
1865
2088
|
if m == "mxtmin":
|
1866
|
-
return
|
2089
|
+
return source_stats().min
|
1867
2090
|
|
1868
2091
|
if m == "mxtmax":
|
1869
|
-
return
|
2092
|
+
return source_stats().max
|
1870
2093
|
|
1871
2094
|
if m == "mxtpkdb":
|
1872
|
-
return
|
2095
|
+
return source_stats().pkdb
|
1873
2096
|
|
1874
2097
|
if m == "mxtlrms":
|
1875
|
-
return
|
2098
|
+
return source_stats().lrms
|
1876
2099
|
|
1877
2100
|
if m == "mxtpkr":
|
1878
|
-
return
|
2101
|
+
return source_stats().pkr
|
1879
2102
|
|
1880
2103
|
if m == "mxttr":
|
1881
|
-
return
|
2104
|
+
return source_stats().tr
|
1882
2105
|
|
1883
2106
|
if m == "mxtcr":
|
1884
|
-
return
|
2107
|
+
return source_stats().cr
|
1885
2108
|
|
1886
2109
|
if m == "mxtfl":
|
1887
|
-
return
|
2110
|
+
return source_stats().fl
|
1888
2111
|
|
1889
2112
|
if m == "mxtpkc":
|
1890
|
-
return
|
2113
|
+
return source_stats().pkc
|
1891
2114
|
|
1892
|
-
if m == "
|
1893
|
-
return
|
2115
|
+
if m == "sdco":
|
2116
|
+
return {category: s.dco for category, s in sources_stats().items()}
|
1894
2117
|
|
1895
|
-
if m == "
|
1896
|
-
return
|
2118
|
+
if m == "smin":
|
2119
|
+
return {category: s.min for category, s in sources_stats().items()}
|
1897
2120
|
|
1898
|
-
if m == "
|
1899
|
-
return
|
2121
|
+
if m == "smax":
|
2122
|
+
return {category: s.max for category, s in sources_stats().items()}
|
1900
2123
|
|
1901
|
-
if m == "
|
1902
|
-
return
|
2124
|
+
if m == "spkdb":
|
2125
|
+
return {category: s.pkdb for category, s in sources_stats().items()}
|
1903
2126
|
|
1904
|
-
if m == "
|
1905
|
-
return
|
2127
|
+
if m == "slrms":
|
2128
|
+
return {category: s.lrms for category, s in sources_stats().items()}
|
1906
2129
|
|
1907
|
-
if m == "
|
1908
|
-
return
|
2130
|
+
if m == "spkr":
|
2131
|
+
return {category: s.pkr for category, s in sources_stats().items()}
|
1909
2132
|
|
1910
|
-
if m == "
|
1911
|
-
return
|
2133
|
+
if m == "str":
|
2134
|
+
return {category: s.tr for category, s in sources_stats().items()}
|
1912
2135
|
|
1913
|
-
if m == "
|
1914
|
-
return
|
2136
|
+
if m == "scr":
|
2137
|
+
return {category: s.cr for category, s in sources_stats().items()}
|
1915
2138
|
|
1916
|
-
if m == "
|
1917
|
-
return
|
2139
|
+
if m == "sfl":
|
2140
|
+
return {category: s.fl for category, s in sources_stats().items()}
|
1918
2141
|
|
1919
|
-
if m == "
|
1920
|
-
return
|
2142
|
+
if m == "spkc":
|
2143
|
+
return {category: s.pkc for category, s in sources_stats().items()}
|
1921
2144
|
|
1922
|
-
if m.startswith("
|
1923
|
-
return
|
2145
|
+
if m.startswith("sasr"):
|
2146
|
+
return sources_asr(get_asr_name(m))
|
1924
2147
|
|
1925
|
-
if m.startswith("
|
1926
|
-
return
|
2148
|
+
if m.startswith("mxsasr"):
|
2149
|
+
return source_asr(get_asr_name(m))
|
1927
2150
|
|
1928
2151
|
if m == "ndco":
|
1929
2152
|
return noise_stats().dco
|
@@ -2022,82 +2245,85 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
2022
2245
|
)
|
2023
2246
|
|
2024
2247
|
|
2025
|
-
def
|
2026
|
-
"""Get
|
2248
|
+
def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
|
2249
|
+
"""Get number of source files from category from db
|
2027
2250
|
|
2028
2251
|
:param db: Database context
|
2029
|
-
:param
|
2252
|
+
:param category: Source category
|
2030
2253
|
:param use_cache: If true, use LRU caching
|
2031
|
-
:return:
|
2254
|
+
:return: Number of source files
|
2032
2255
|
"""
|
2033
2256
|
if use_cache:
|
2034
|
-
return
|
2035
|
-
return
|
2257
|
+
return __num_source_files(db, category)
|
2258
|
+
return __num_source_files.__wrapped__(db, category)
|
2036
2259
|
|
2037
2260
|
|
2038
2261
|
@lru_cache
|
2039
|
-
def
|
2040
|
-
"""Get
|
2262
|
+
def __num_source_files(db: partial, category: str) -> int:
|
2263
|
+
"""Get number of source files from category from db
|
2041
2264
|
|
2042
2265
|
:param db: Database context
|
2043
|
-
:param
|
2044
|
-
:
|
2045
|
-
:return: Target file
|
2266
|
+
:param category: Source category
|
2267
|
+
:return: Number of source files
|
2046
2268
|
"""
|
2047
|
-
import json
|
2048
|
-
|
2049
|
-
from .db_datatypes import TargetFileRecord
|
2050
|
-
|
2051
2269
|
with db() as c:
|
2052
|
-
|
2053
|
-
|
2054
|
-
""
|
2055
|
-
|
2056
|
-
FROM target_file
|
2057
|
-
WHERE ? = target_file.id
|
2058
|
-
""",
|
2059
|
-
(t_id,),
|
2060
|
-
).fetchone()
|
2061
|
-
)
|
2062
|
-
|
2063
|
-
return TargetFile(
|
2064
|
-
name=target_file.name,
|
2065
|
-
samples=target_file.samples,
|
2066
|
-
class_indices=json.loads(target_file.class_indices),
|
2067
|
-
level_type=target_file.level_type,
|
2068
|
-
truth_configs=_target_truth_configs(db, t_id, use_cache),
|
2069
|
-
speaker_id=target_file.speaker_id,
|
2270
|
+
return int(
|
2271
|
+
c.execute(
|
2272
|
+
"SELECT count(source_file.id) FROM source_file WHERE ? = source_file.category", (category,)
|
2273
|
+
).fetchone()[0]
|
2070
2274
|
)
|
2071
2275
|
|
2072
2276
|
|
2073
|
-
def
|
2074
|
-
"""Get
|
2277
|
+
def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
|
2278
|
+
"""Get source file with ID from db
|
2075
2279
|
|
2076
2280
|
:param db: Database context
|
2077
|
-
:param
|
2281
|
+
:param s_id: Source file ID
|
2078
2282
|
:param use_cache: If true, use LRU caching
|
2079
|
-
:return:
|
2283
|
+
:return: Source file
|
2080
2284
|
"""
|
2081
2285
|
if use_cache:
|
2082
|
-
return
|
2083
|
-
return
|
2286
|
+
return __source_file(db, s_id, use_cache)
|
2287
|
+
return __source_file.__wrapped__(db, s_id, use_cache)
|
2084
2288
|
|
2085
2289
|
|
2086
2290
|
@lru_cache
|
2087
|
-
def
|
2291
|
+
def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
|
2292
|
+
"""Get source file with ID from db
|
2293
|
+
|
2294
|
+
:param db: Database context
|
2295
|
+
:param s_id: Source file ID
|
2296
|
+
:param use_cache: If true, use LRU caching
|
2297
|
+
:return: Source file
|
2298
|
+
"""
|
2299
|
+
import json
|
2300
|
+
|
2301
|
+
from .db_datatypes import SourceFileRecord
|
2302
|
+
|
2088
2303
|
with db() as c:
|
2089
|
-
|
2090
|
-
|
2091
|
-
|
2092
|
-
|
2093
|
-
|
2094
|
-
|
2095
|
-
|
2096
|
-
|
2097
|
-
|
2304
|
+
source_file = SourceFileRecord(
|
2305
|
+
*c.execute(
|
2306
|
+
"""
|
2307
|
+
SELECT *
|
2308
|
+
FROM source_file
|
2309
|
+
WHERE ? = source_file.id
|
2310
|
+
""",
|
2311
|
+
(s_id,),
|
2312
|
+
).fetchone()
|
2313
|
+
)
|
2314
|
+
|
2315
|
+
return SourceFile(
|
2316
|
+
category=source_file.category,
|
2317
|
+
name=source_file.name,
|
2318
|
+
samples=source_file.samples,
|
2319
|
+
class_indices=json.loads(source_file.class_indices),
|
2320
|
+
level_type=source_file.level_type,
|
2321
|
+
truth_configs=_source_truth_configs(db, s_id, use_cache),
|
2322
|
+
speaker_id=source_file.speaker_id,
|
2323
|
+
)
|
2098
2324
|
|
2099
2325
|
|
2100
|
-
def
|
2326
|
+
def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
|
2101
2327
|
"""Get impulse response file name with ID from db
|
2102
2328
|
|
2103
2329
|
:param db: Database context
|
@@ -2106,26 +2332,26 @@ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> s
|
|
2106
2332
|
:return: Impulse response file name
|
2107
2333
|
"""
|
2108
2334
|
if use_cache:
|
2109
|
-
return
|
2110
|
-
return
|
2335
|
+
return __ir_file(db, ir_id)
|
2336
|
+
return __ir_file.__wrapped__(db, ir_id)
|
2111
2337
|
|
2112
2338
|
|
2113
2339
|
@lru_cache
|
2114
|
-
def
|
2340
|
+
def __ir_file(db: partial, ir_id: int) -> str:
|
2115
2341
|
with db() as c:
|
2116
2342
|
return str(
|
2117
2343
|
c.execute(
|
2118
2344
|
"""
|
2119
|
-
SELECT
|
2120
|
-
FROM
|
2121
|
-
WHERE ? =
|
2345
|
+
SELECT ir_file.name
|
2346
|
+
FROM ir_file
|
2347
|
+
WHERE ? = ir_file.id
|
2122
2348
|
""",
|
2123
2349
|
(ir_id + 1,),
|
2124
2350
|
).fetchone()[0]
|
2125
2351
|
)
|
2126
2352
|
|
2127
2353
|
|
2128
|
-
def
|
2354
|
+
def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
|
2129
2355
|
"""Get impulse response delay with ID from db
|
2130
2356
|
|
2131
2357
|
:param db: Database context
|
@@ -2134,19 +2360,19 @@ def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) ->
|
|
2134
2360
|
:return: Impulse response delay
|
2135
2361
|
"""
|
2136
2362
|
if use_cache:
|
2137
|
-
return
|
2138
|
-
return
|
2363
|
+
return __ir_delay(db, ir_id)
|
2364
|
+
return __ir_delay.__wrapped__(db, ir_id)
|
2139
2365
|
|
2140
2366
|
|
2141
2367
|
@lru_cache
|
2142
|
-
def
|
2368
|
+
def __ir_delay(db: partial, ir_id: int) -> int:
|
2143
2369
|
with db() as c:
|
2144
2370
|
return int(
|
2145
2371
|
c.execute(
|
2146
2372
|
"""
|
2147
|
-
SELECT
|
2148
|
-
FROM
|
2149
|
-
WHERE ? =
|
2373
|
+
SELECT ir_file.delay
|
2374
|
+
FROM ir_file
|
2375
|
+
WHERE ? = ir_file.id
|
2150
2376
|
""",
|
2151
2377
|
(ir_id + 1,),
|
2152
2378
|
).fetchone()[0]
|
@@ -2169,9 +2395,9 @@ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
|
2169
2395
|
@lru_cache
|
2170
2396
|
def __mixture(db: partial, m_id: int) -> Mixture:
|
2171
2397
|
from .db_datatypes import MixtureRecord
|
2172
|
-
from .db_datatypes import
|
2398
|
+
from .db_datatypes import SourceRecord
|
2173
2399
|
from .helpers import to_mixture
|
2174
|
-
from .helpers import
|
2400
|
+
from .helpers import to_source
|
2175
2401
|
|
2176
2402
|
with db() as c:
|
2177
2403
|
mixture = MixtureRecord(
|
@@ -2185,19 +2411,20 @@ def __mixture(db: partial, m_id: int) -> Mixture:
|
|
2185
2411
|
).fetchone()
|
2186
2412
|
)
|
2187
2413
|
|
2188
|
-
|
2189
|
-
|
2190
|
-
|
2191
|
-
|
2192
|
-
|
2193
|
-
|
2194
|
-
WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
|
2414
|
+
sources: Sources = {}
|
2415
|
+
for source in c.execute(
|
2416
|
+
"""
|
2417
|
+
SELECT source.*
|
2418
|
+
FROM source, mixture_source
|
2419
|
+
WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
|
2195
2420
|
""",
|
2196
|
-
|
2197
|
-
|
2198
|
-
|
2421
|
+
(mixture.id,),
|
2422
|
+
).fetchall():
|
2423
|
+
s = SourceRecord(*source)
|
2424
|
+
category = c.execute("SELECT category FROM source_file WHERE ? = id", (s.file_id,)).fetchone()[0]
|
2425
|
+
sources[category] = to_source(s)
|
2199
2426
|
|
2200
|
-
return to_mixture(mixture,
|
2427
|
+
return to_mixture(mixture, sources)
|
2201
2428
|
|
2202
2429
|
|
2203
2430
|
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
|
@@ -2220,27 +2447,62 @@ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
|
2220
2447
|
return data[0]
|
2221
2448
|
|
2222
2449
|
|
2223
|
-
def
|
2450
|
+
def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
|
2451
|
+
if use_cache:
|
2452
|
+
return __category_truth_configs(db, category)
|
2453
|
+
return __category_truth_configs.__wrapped__(db, category)
|
2454
|
+
|
2455
|
+
|
2456
|
+
@lru_cache
|
2457
|
+
def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
|
2458
|
+
import json
|
2459
|
+
|
2460
|
+
truth_configs: dict[str, str] = {}
|
2461
|
+
with db() as c:
|
2462
|
+
s_ids = c.execute(
|
2463
|
+
"""
|
2464
|
+
SELECT id
|
2465
|
+
FROM source_file
|
2466
|
+
WHERE ? = category
|
2467
|
+
""",
|
2468
|
+
(category,),
|
2469
|
+
).fetchall()
|
2470
|
+
|
2471
|
+
for s_id in s_ids:
|
2472
|
+
for truth_config_record in c.execute(
|
2473
|
+
"""
|
2474
|
+
SELECT truth_config.config
|
2475
|
+
FROM truth_config, source_file_truth_config
|
2476
|
+
WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
|
2477
|
+
""",
|
2478
|
+
(s_id[0],),
|
2479
|
+
).fetchall():
|
2480
|
+
truth_config = json.loads(truth_config_record[0])
|
2481
|
+
truth_configs[truth_config["name"]] = truth_config["function"]
|
2482
|
+
return truth_configs
|
2483
|
+
|
2484
|
+
|
2485
|
+
def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
|
2224
2486
|
if use_cache:
|
2225
|
-
return
|
2226
|
-
return
|
2487
|
+
return __source_truth_configs(db, s_id)
|
2488
|
+
return __source_truth_configs.__wrapped__(db, s_id)
|
2227
2489
|
|
2228
2490
|
|
2229
2491
|
@lru_cache
|
2230
|
-
def
|
2492
|
+
def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
|
2231
2493
|
import json
|
2232
2494
|
|
2233
|
-
from
|
2495
|
+
from ..datatypes import TruthConfig
|
2234
2496
|
|
2235
2497
|
truth_configs: TruthConfigs = {}
|
2236
2498
|
with db() as c:
|
2237
2499
|
for truth_config_record in c.execute(
|
2238
2500
|
"""
|
2239
2501
|
SELECT truth_config.config
|
2240
|
-
FROM truth_config,
|
2241
|
-
WHERE ? =
|
2502
|
+
FROM truth_config, source_file_truth_config
|
2503
|
+
WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
|
2242
2504
|
""",
|
2243
|
-
(
|
2505
|
+
(s_id,),
|
2244
2506
|
).fetchall():
|
2245
2507
|
truth_config = json.loads(truth_config_record[0])
|
2246
2508
|
truth_configs[truth_config["name"]] = TruthConfig(
|