sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -6,36 +6,43 @@ from sqlite3 import Connection
|
|
6
6
|
from sqlite3 import Cursor
|
7
7
|
from typing import Any
|
8
8
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
25
|
-
from
|
26
|
-
from
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
9
|
+
from ..datatypes import ASRConfigs
|
10
|
+
from ..datatypes import AudioF
|
11
|
+
from ..datatypes import AudioT
|
12
|
+
from ..datatypes import ClassCount
|
13
|
+
from ..datatypes import Feature
|
14
|
+
from ..datatypes import FeatureGeneratorConfig
|
15
|
+
from ..datatypes import FeatureGeneratorInfo
|
16
|
+
from ..datatypes import GeneralizedIDs
|
17
|
+
from ..datatypes import ImpulseResponseFile
|
18
|
+
from ..datatypes import MetricDoc
|
19
|
+
from ..datatypes import MetricDocs
|
20
|
+
from ..datatypes import Mixture
|
21
|
+
from ..datatypes import Segsnr
|
22
|
+
from ..datatypes import SourceFile
|
23
|
+
from ..datatypes import Sources
|
24
|
+
from ..datatypes import SourcesAudioF
|
25
|
+
from ..datatypes import SourcesAudioT
|
26
|
+
from ..datatypes import SpectralMask
|
27
|
+
from ..datatypes import SpeechMetadata
|
28
|
+
from ..datatypes import TransformConfig
|
29
|
+
from ..datatypes import TruthConfigs
|
30
|
+
from ..datatypes import TruthDict
|
31
|
+
from ..datatypes import TruthsConfigs
|
32
|
+
from ..datatypes import TruthsDict
|
33
|
+
from ..datatypes import UniversalSNR
|
30
34
|
|
31
35
|
|
32
36
|
def db_file(location: str, test: bool = False) -> str:
|
33
37
|
from os.path import join
|
34
38
|
|
39
|
+
from .constants import MIXDB_NAME
|
40
|
+
from .constants import TEST_MIXDB_NAME
|
41
|
+
|
35
42
|
if test:
|
36
|
-
name =
|
43
|
+
name = TEST_MIXDB_NAME
|
37
44
|
else:
|
38
|
-
name =
|
45
|
+
name = MIXDB_NAME
|
39
46
|
|
40
47
|
return join(location, name)
|
41
48
|
|
@@ -103,7 +110,7 @@ class MixtureDatabase:
|
|
103
110
|
config = load_config(self.location)
|
104
111
|
new_asr_configs = json.dumps(config["asr_configs"])
|
105
112
|
with self.db() as c:
|
106
|
-
old_asr_configs = c.execute("SELECT
|
113
|
+
old_asr_configs = c.execute("SELECT asr_configs FROM top").fetchone()
|
107
114
|
|
108
115
|
if old_asr_configs is not None and new_asr_configs != old_asr_configs[0]:
|
109
116
|
con = db_connection(location=self.location, readonly=False, test=self.test)
|
@@ -113,7 +120,7 @@ class MixtureDatabase:
|
|
113
120
|
|
114
121
|
@cached_property
|
115
122
|
def json(self) -> str:
|
116
|
-
from
|
123
|
+
from ..datatypes import MixtureDatabaseConfig
|
117
124
|
|
118
125
|
config = MixtureDatabaseConfig(
|
119
126
|
asr_configs=self.asr_configs,
|
@@ -121,13 +128,11 @@ class MixtureDatabase:
|
|
121
128
|
class_labels=self.class_labels,
|
122
129
|
class_weights_threshold=self.class_weights_thresholds,
|
123
130
|
feature=self.feature,
|
124
|
-
|
125
|
-
mixtures=self.mixtures
|
126
|
-
noise_mix_mode=self.noise_mix_mode,
|
127
|
-
noise_files=self.noise_files,
|
131
|
+
ir_files=self.ir_files,
|
132
|
+
mixtures=self.mixtures,
|
128
133
|
num_classes=self.num_classes,
|
129
134
|
spectral_masks=self.spectral_masks,
|
130
|
-
|
135
|
+
source_files=self.source_files,
|
131
136
|
)
|
132
137
|
return config.to_json(indent=2)
|
133
138
|
|
@@ -153,30 +158,28 @@ class MixtureDatabase:
|
|
153
158
|
return get_feature_generator_info(self.fg_config)
|
154
159
|
|
155
160
|
@cached_property
|
156
|
-
def truth_parameters(self) -> dict[str, int | None]:
|
161
|
+
def truth_parameters(self) -> dict[str, dict[str, int | None]]:
|
157
162
|
with self.db() as c:
|
158
|
-
rows = c.execute("SELECT
|
159
|
-
truth_parameters: dict[str, int | None] = {}
|
163
|
+
rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
|
164
|
+
truth_parameters: dict[str, dict[str, int | None]] = {}
|
160
165
|
for row in rows:
|
161
|
-
|
166
|
+
category, name, parameters = row
|
167
|
+
if category not in truth_parameters:
|
168
|
+
truth_parameters[category] = {}
|
169
|
+
truth_parameters[category][name] = parameters
|
162
170
|
return truth_parameters
|
163
171
|
|
164
172
|
@cached_property
|
165
173
|
def num_classes(self) -> int:
|
166
174
|
with self.db() as c:
|
167
|
-
return int(c.execute("SELECT
|
168
|
-
|
169
|
-
@cached_property
|
170
|
-
def noise_mix_mode(self) -> str:
|
171
|
-
with self.db() as c:
|
172
|
-
return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
|
175
|
+
return int(c.execute("SELECT num_classes FROM top").fetchone()[0])
|
173
176
|
|
174
177
|
@cached_property
|
175
178
|
def asr_configs(self) -> ASRConfigs:
|
176
179
|
import json
|
177
180
|
|
178
181
|
with self.db() as c:
|
179
|
-
return json.loads(c.execute("SELECT
|
182
|
+
return json.loads(c.execute("SELECT asr_configs FROM top").fetchone()[0])
|
180
183
|
|
181
184
|
@cached_property
|
182
185
|
def supported_metrics(self) -> MetricDocs:
|
@@ -223,36 +226,36 @@ class MixtureDatabase:
|
|
223
226
|
"mxssnrdbf_std",
|
224
227
|
"Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
|
225
228
|
),
|
226
|
-
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true
|
229
|
+
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true sources"),
|
227
230
|
MetricDoc(
|
228
231
|
"Mixture Metrics",
|
229
232
|
"mxwsdr",
|
230
|
-
"Weighted signal distortion ratio of mixture versus true
|
233
|
+
"Weighted signal distortion ratio of mixture versus true sources",
|
231
234
|
),
|
232
235
|
MetricDoc(
|
233
236
|
"Mixture Metrics",
|
234
237
|
"mxpd",
|
235
|
-
"Phase distance between mixture and true
|
238
|
+
"Phase distance between mixture and true sources",
|
236
239
|
),
|
237
240
|
MetricDoc(
|
238
241
|
"Mixture Metrics",
|
239
242
|
"mxstoi",
|
240
|
-
"Short term objective intelligibility of mixture versus true
|
243
|
+
"Short term objective intelligibility of mixture versus true sources",
|
241
244
|
),
|
242
245
|
MetricDoc(
|
243
246
|
"Mixture Metrics",
|
244
247
|
"mxcsig",
|
245
|
-
"Predicted rating of speech distortion of mixture versus true
|
248
|
+
"Predicted rating of speech distortion of mixture versus true sources",
|
246
249
|
),
|
247
250
|
MetricDoc(
|
248
251
|
"Mixture Metrics",
|
249
252
|
"mxcbak",
|
250
|
-
"Predicted rating of background distortion of mixture versus true
|
253
|
+
"Predicted rating of background distortion of mixture versus true sources",
|
251
254
|
),
|
252
255
|
MetricDoc(
|
253
256
|
"Mixture Metrics",
|
254
257
|
"mxcovl",
|
255
|
-
"Predicted rating of overall quality of mixture versus true
|
258
|
+
"Predicted rating of overall quality of mixture versus true sources",
|
256
259
|
),
|
257
260
|
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
258
261
|
MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
|
@@ -265,26 +268,26 @@ class MixtureDatabase:
|
|
265
268
|
MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
|
266
269
|
MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
|
267
270
|
MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
|
268
|
-
MetricDoc("Mixture Metrics", "mxtdco", "Mixture
|
269
|
-
MetricDoc("Mixture Metrics", "mxtmin", "Mixture
|
270
|
-
MetricDoc("Mixture Metrics", "mxtmax", "Mixture
|
271
|
-
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture
|
272
|
-
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture
|
273
|
-
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture
|
274
|
-
MetricDoc("Mixture Metrics", "mxttr", "Mixture
|
275
|
-
MetricDoc("Mixture Metrics", "mxtcr", "Mixture
|
276
|
-
MetricDoc("Mixture Metrics", "mxtfl", "Mixture
|
277
|
-
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture
|
278
|
-
MetricDoc("
|
279
|
-
MetricDoc("
|
280
|
-
MetricDoc("
|
281
|
-
MetricDoc("
|
282
|
-
MetricDoc("
|
283
|
-
MetricDoc("
|
284
|
-
MetricDoc("
|
285
|
-
MetricDoc("
|
286
|
-
MetricDoc("
|
287
|
-
MetricDoc("
|
271
|
+
MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
|
272
|
+
MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
|
273
|
+
MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
|
274
|
+
MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
|
275
|
+
MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
|
276
|
+
MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
|
277
|
+
MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
|
278
|
+
MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
|
279
|
+
MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
|
280
|
+
MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
|
281
|
+
MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
|
282
|
+
MetricDoc("Sources Metrics", "smin", "Sources min level"),
|
283
|
+
MetricDoc("Sources Metrics", "smax", "Sources max levl"),
|
284
|
+
MetricDoc("Sources Metrics", "spkdb", "Sources Pk lev dB"),
|
285
|
+
MetricDoc("Sources Metrics", "slrms", "Sources RMS lev dB"),
|
286
|
+
MetricDoc("Sources Metrics", "spkr", "Sources RMS Pk dB"),
|
287
|
+
MetricDoc("Sources Metrics", "str", "Sources RMS Tr dB"),
|
288
|
+
MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
|
289
|
+
MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
|
290
|
+
MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
|
288
291
|
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
289
292
|
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
290
293
|
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
@@ -320,16 +323,16 @@ class MixtureDatabase:
|
|
320
323
|
for name in self.asr_configs:
|
321
324
|
metrics.append(
|
322
325
|
MetricDoc(
|
323
|
-
"
|
324
|
-
f"
|
325
|
-
f"Mixture
|
326
|
+
"Source Metrics",
|
327
|
+
f"mxsasr.{name}",
|
328
|
+
f"Mixture Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
326
329
|
)
|
327
330
|
)
|
328
331
|
metrics.append(
|
329
332
|
MetricDoc(
|
330
|
-
"
|
331
|
-
f"
|
332
|
-
f"
|
333
|
+
"Source Metrics",
|
334
|
+
f"sasr.{name}",
|
335
|
+
f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
333
336
|
)
|
334
337
|
)
|
335
338
|
metrics.append(
|
@@ -341,16 +344,16 @@ class MixtureDatabase:
|
|
341
344
|
)
|
342
345
|
metrics.append(
|
343
346
|
MetricDoc(
|
344
|
-
"
|
347
|
+
"Source Metrics",
|
345
348
|
f"basewer.{name}",
|
346
|
-
f"Word error rate of
|
349
|
+
f"Word error rate of sasr.{name} vs. speech text metadata for the source",
|
347
350
|
)
|
348
351
|
)
|
349
352
|
metrics.append(
|
350
353
|
MetricDoc(
|
351
354
|
"Mixture Metrics",
|
352
355
|
f"mxwer.{name}",
|
353
|
-
f"Word error rate of mxasr.{name} vs.
|
356
|
+
f"Word error rate of mxasr.{name} vs. sasr.{name}",
|
354
357
|
)
|
355
358
|
)
|
356
359
|
|
@@ -359,12 +362,12 @@ class MixtureDatabase:
|
|
359
362
|
@cached_property
|
360
363
|
def class_balancing(self) -> bool:
|
361
364
|
with self.db() as c:
|
362
|
-
return bool(c.execute("SELECT
|
365
|
+
return bool(c.execute("SELECT class_balancing FROM top").fetchone()[0])
|
363
366
|
|
364
367
|
@cached_property
|
365
368
|
def feature(self) -> str:
|
366
369
|
with self.db() as c:
|
367
|
-
return str(c.execute("SELECT
|
370
|
+
return str(c.execute("SELECT feature FROM top").fetchone()[0])
|
368
371
|
|
369
372
|
@cached_property
|
370
373
|
def fg_decimation(self) -> int:
|
@@ -396,7 +399,7 @@ class MixtureDatabase:
|
|
396
399
|
|
397
400
|
@cached_property
|
398
401
|
def transform_frame_ms(self) -> float:
|
399
|
-
from
|
402
|
+
from ..constants import SAMPLE_RATE
|
400
403
|
|
401
404
|
return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
|
402
405
|
|
@@ -417,12 +420,7 @@ class MixtureDatabase:
|
|
417
420
|
return self.ft_config.overlap * self.fg_decimation * self.fg_step
|
418
421
|
|
419
422
|
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
|
420
|
-
samples
|
421
|
-
for m_id in self.mixids_to_list(m_ids):
|
422
|
-
s = self.mixture(m_id).samples
|
423
|
-
if s is not None:
|
424
|
-
samples += s
|
425
|
-
return samples
|
423
|
+
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
|
426
424
|
|
427
425
|
def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
428
426
|
return self.total_samples(m_ids) // self.ft_config.overlap
|
@@ -457,10 +455,7 @@ class MixtureDatabase:
|
|
457
455
|
:return: Class labels
|
458
456
|
"""
|
459
457
|
with self.db() as c:
|
460
|
-
return [
|
461
|
-
str(item[0])
|
462
|
-
for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
|
463
|
-
]
|
458
|
+
return [str(item[0]) for item in c.execute("SELECT label FROM class_label ORDER BY id").fetchall()]
|
464
459
|
|
465
460
|
@cached_property
|
466
461
|
def class_weights_thresholds(self) -> list[float]:
|
@@ -469,37 +464,20 @@ class MixtureDatabase:
|
|
469
464
|
:return: Class weights thresholds
|
470
465
|
"""
|
471
466
|
with self.db() as c:
|
472
|
-
return [
|
473
|
-
float(item[0])
|
474
|
-
for item in c.execute(
|
475
|
-
"SELECT class_weights_threshold.threshold FROM class_weights_threshold"
|
476
|
-
).fetchall()
|
477
|
-
]
|
478
|
-
|
479
|
-
@cached_property
|
480
|
-
def truth_configs(self) -> TruthConfigs:
|
481
|
-
"""Get truth configs from db
|
482
|
-
|
483
|
-
:return: Truth configs
|
484
|
-
"""
|
485
|
-
import json
|
467
|
+
return [float(item[0]) for item in c.execute("SELECT threshold FROM class_weights_threshold").fetchall()]
|
486
468
|
|
487
|
-
|
469
|
+
def category_truth_configs(self, category: str) -> dict[str, str]:
|
470
|
+
return _category_truth_configs(self.db, category, self.use_cache)
|
488
471
|
|
489
|
-
|
490
|
-
|
491
|
-
for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
|
492
|
-
truth_config = json.loads(truth_config_record[0])
|
493
|
-
if truth_config["name"] not in truth_configs:
|
494
|
-
truth_configs[truth_config["name"]] = TruthConfig(
|
495
|
-
function=truth_config["function"],
|
496
|
-
stride_reduction=truth_config["stride_reduction"],
|
497
|
-
config=truth_config["config"],
|
498
|
-
)
|
499
|
-
return truth_configs
|
472
|
+
def source_truth_configs(self, s_id: int) -> TruthConfigs:
|
473
|
+
return _source_truth_configs(self.db, s_id, self.use_cache)
|
500
474
|
|
501
|
-
def
|
502
|
-
|
475
|
+
def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
|
476
|
+
mixture = self.mixture(m_id)
|
477
|
+
return {
|
478
|
+
category: self.source_truth_configs(mixture.all_sources[category].file_id)
|
479
|
+
for category in mixture.all_sources
|
480
|
+
}
|
503
481
|
|
504
482
|
@cached_property
|
505
483
|
def random_snrs(self) -> list[float]:
|
@@ -509,10 +487,7 @@ class MixtureDatabase:
|
|
509
487
|
"""
|
510
488
|
with self.db() as c:
|
511
489
|
return list(
|
512
|
-
{
|
513
|
-
float(item[0])
|
514
|
-
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
|
515
|
-
}
|
490
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()}
|
516
491
|
)
|
517
492
|
|
518
493
|
@cached_property
|
@@ -523,10 +498,7 @@ class MixtureDatabase:
|
|
523
498
|
"""
|
524
499
|
with self.db() as c:
|
525
500
|
return list(
|
526
|
-
{
|
527
|
-
float(item[0])
|
528
|
-
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
|
529
|
-
}
|
501
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()}
|
530
502
|
)
|
531
503
|
|
532
504
|
@cached_property
|
@@ -570,199 +542,216 @@ class MixtureDatabase:
|
|
570
542
|
return _spectral_mask(self.db, sm_id, self.use_cache)
|
571
543
|
|
572
544
|
@cached_property
|
573
|
-
def
|
574
|
-
"""Get
|
545
|
+
def source_files(self) -> dict[str, list[SourceFile]]:
|
546
|
+
"""Get source files from db
|
575
547
|
|
576
|
-
:return:
|
548
|
+
:return: Source files
|
577
549
|
"""
|
578
550
|
import json
|
579
551
|
|
580
|
-
from
|
581
|
-
from
|
582
|
-
from .db_datatypes import
|
552
|
+
from ..datatypes import TruthConfig
|
553
|
+
from ..datatypes import TruthConfigs
|
554
|
+
from .db_datatypes import SourceFileRecord
|
583
555
|
|
584
556
|
with self.db() as c:
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
557
|
+
source_files: dict[str, list[SourceFile]] = {}
|
558
|
+
categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
|
559
|
+
for category in categories:
|
560
|
+
source_files[category[0]] = []
|
561
|
+
source_file_records = [
|
562
|
+
SourceFileRecord(*result)
|
563
|
+
for result in c.execute("SELECT * FROM source_file WHERE ? = category", (category[0],)).fetchall()
|
564
|
+
]
|
565
|
+
for source_file_record in source_file_records:
|
566
|
+
truth_configs: TruthConfigs = {}
|
567
|
+
for truth_config_records in c.execute(
|
568
|
+
"""
|
569
|
+
SELECT truth_config.config
|
570
|
+
FROM truth_config, source_file_truth_config
|
571
|
+
WHERE ? = source_file_truth_config.source_file_id
|
572
|
+
AND truth_config.id = source_file_truth_config.truth_config_id
|
573
|
+
""",
|
574
|
+
(source_file_record.id,),
|
575
|
+
).fetchall():
|
576
|
+
truth_config = json.loads(truth_config_records[0])
|
577
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
578
|
+
function=truth_config["function"],
|
579
|
+
stride_reduction=truth_config["stride_reduction"],
|
580
|
+
config=truth_config["config"],
|
581
|
+
)
|
582
|
+
source_files[source_file_record.category].append(
|
583
|
+
SourceFile(
|
584
|
+
id=source_file_record.id,
|
585
|
+
category=source_file_record.category,
|
586
|
+
name=source_file_record.name,
|
587
|
+
samples=source_file_record.samples,
|
588
|
+
class_indices=json.loads(source_file_record.class_indices),
|
589
|
+
level_type=source_file_record.level_type,
|
590
|
+
truth_configs=truth_configs,
|
591
|
+
speaker_id=source_file_record.speaker_id,
|
592
|
+
)
|
614
593
|
)
|
615
|
-
|
616
|
-
return target_files
|
594
|
+
return source_files
|
617
595
|
|
618
596
|
@cached_property
|
619
|
-
def
|
620
|
-
"""Get
|
597
|
+
def source_file_ids(self) -> dict[str, list[int]]:
|
598
|
+
"""Get source file IDs from db
|
621
599
|
|
622
|
-
:return:
|
600
|
+
:return: Dictionary of list of source file IDs
|
623
601
|
"""
|
624
602
|
with self.db() as c:
|
625
|
-
|
603
|
+
source_file_ids: dict[str, list[int]] = {}
|
604
|
+
categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
|
605
|
+
for category in categories:
|
606
|
+
source_file_ids[category[0]] = [
|
607
|
+
int(item[0])
|
608
|
+
for item in c.execute("SELECT id FROM source_file WHERE ? = category", (category[0],)).fetchall()
|
609
|
+
]
|
610
|
+
return source_file_ids
|
626
611
|
|
627
|
-
def
|
628
|
-
"""Get
|
612
|
+
def source_file(self, s_id: int) -> SourceFile:
|
613
|
+
"""Get source file with ID from db
|
629
614
|
|
630
|
-
:param
|
631
|
-
:return:
|
615
|
+
:param s_id: Source file ID
|
616
|
+
:return: Source file
|
632
617
|
"""
|
633
|
-
return
|
618
|
+
return _source_file(self.db, s_id, self.use_cache)
|
634
619
|
|
635
|
-
|
636
|
-
|
637
|
-
"""Get number of target files from db
|
620
|
+
def num_source_files(self, category: str) -> int:
|
621
|
+
"""Get number of source files from category from db
|
638
622
|
|
639
|
-
:
|
623
|
+
:param category: Source category
|
624
|
+
:return: Number of source files
|
640
625
|
"""
|
641
|
-
|
642
|
-
return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
|
626
|
+
return _num_source_files(self.db, category, self.use_cache)
|
643
627
|
|
644
628
|
@cached_property
|
645
|
-
def
|
646
|
-
"""Get
|
629
|
+
def ir_files(self) -> list[ImpulseResponseFile]:
|
630
|
+
"""Get impulse response files from db
|
647
631
|
|
648
|
-
:return:
|
632
|
+
:return: Impulse response files
|
649
633
|
"""
|
650
|
-
|
651
|
-
return [
|
652
|
-
NoiseFile(name=noise[0], samples=noise[1])
|
653
|
-
for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
|
654
|
-
]
|
655
|
-
|
656
|
-
@cached_property
|
657
|
-
def noise_file_ids(self) -> list[int]:
|
658
|
-
"""Get noise file IDs from db
|
634
|
+
from .db_datatypes import ImpulseResponseFileRecord
|
659
635
|
|
660
|
-
:return: List of noise file IDs
|
661
|
-
"""
|
662
636
|
with self.db() as c:
|
663
|
-
|
637
|
+
files: list[ImpulseResponseFile] = []
|
638
|
+
entries = c.execute("SELECT * FROM ir_file").fetchall()
|
639
|
+
for entry in entries:
|
640
|
+
file = ImpulseResponseFileRecord(*entry)
|
641
|
+
|
642
|
+
tags = [
|
643
|
+
tag[0]
|
644
|
+
for tag in c.execute(
|
645
|
+
"""
|
646
|
+
SELECT ir_tag.tag
|
647
|
+
FROM ir_tag, ir_file_ir_tag
|
648
|
+
WHERE ? = ir_file_ir_tag.file_id
|
649
|
+
AND ir_tag.id = ir_file_ir_tag.tag_id
|
650
|
+
""",
|
651
|
+
(file.id,),
|
652
|
+
).fetchall()
|
653
|
+
]
|
664
654
|
|
665
|
-
|
666
|
-
|
655
|
+
files.append(
|
656
|
+
ImpulseResponseFile(
|
657
|
+
delay=file.delay,
|
658
|
+
name=file.name,
|
659
|
+
tags=tags,
|
660
|
+
)
|
661
|
+
)
|
667
662
|
|
668
|
-
|
669
|
-
:return: Noise file
|
670
|
-
"""
|
671
|
-
return _noise_file(self.db, n_id, self.use_cache)
|
663
|
+
return files
|
672
664
|
|
673
665
|
@cached_property
|
674
|
-
def
|
675
|
-
"""Get
|
666
|
+
def ir_file_ids(self) -> list[int]:
|
667
|
+
"""Get impulse response file IDs from db
|
676
668
|
|
677
|
-
:return:
|
669
|
+
:return: List of impulse response file IDs
|
678
670
|
"""
|
679
671
|
with self.db() as c:
|
680
|
-
return int(c.execute("SELECT
|
672
|
+
return [int(item[0]) for item in c.execute("SELECT id FROM ir_file").fetchall()]
|
681
673
|
|
682
|
-
|
683
|
-
|
684
|
-
"""Get impulse response files from db
|
674
|
+
def ir_file_ids_for_tag(self, tag: str) -> list[int]:
|
675
|
+
"""Get impulse response file IDs for given tag from db
|
685
676
|
|
686
|
-
:return:
|
677
|
+
:return: List of impulse response file IDs for given tag
|
687
678
|
"""
|
688
|
-
import json
|
689
|
-
|
690
|
-
from .datatypes import ImpulseResponseFile
|
691
|
-
|
692
679
|
with self.db() as c:
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
"SELECT impulse_response_file.* FROM impulse_response_file"
|
697
|
-
).fetchall()
|
698
|
-
]
|
699
|
-
|
700
|
-
@cached_property
|
701
|
-
def impulse_response_file_ids(self) -> list[int]:
|
702
|
-
"""Get impulse response file IDs from db
|
680
|
+
tag_id = c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,)).fetchone()
|
681
|
+
if not tag_id:
|
682
|
+
return []
|
703
683
|
|
704
|
-
:return: List of impulse response file IDs
|
705
|
-
"""
|
706
|
-
with self.db() as c:
|
707
684
|
return [
|
708
|
-
int(item[0])
|
709
|
-
for item in c.execute("SELECT
|
685
|
+
int(item[0] - 1)
|
686
|
+
for item in c.execute("SELECT file_id FROM ir_file_ir_tag WHERE ? = tag_id", (tag_id[0],)).fetchall()
|
710
687
|
]
|
711
688
|
|
712
|
-
def
|
689
|
+
def ir_file(self, ir_id: int) -> str:
|
713
690
|
"""Get impulse response file name with ID from db
|
714
691
|
|
715
692
|
:param ir_id: Impulse response file ID
|
716
693
|
:return: Impulse response file name
|
717
694
|
"""
|
718
|
-
|
719
|
-
return None
|
720
|
-
return _impulse_response_file(self.db, ir_id, self.use_cache)
|
695
|
+
return _ir_file(self.db, ir_id, self.use_cache)
|
721
696
|
|
722
|
-
def
|
697
|
+
def ir_delay(self, ir_id: int) -> int:
|
723
698
|
"""Get impulse response delay with ID from db
|
724
699
|
|
725
700
|
:param ir_id: Impulse response file ID
|
726
701
|
:return: Impulse response delay
|
727
702
|
"""
|
728
|
-
|
729
|
-
return None
|
730
|
-
return _impulse_response_delay(self.db, ir_id, self.use_cache)
|
703
|
+
return _ir_delay(self.db, ir_id, self.use_cache)
|
731
704
|
|
732
705
|
@cached_property
|
733
|
-
def
|
706
|
+
def num_ir_files(self) -> int:
|
734
707
|
"""Get number of impulse response files from db
|
735
708
|
|
736
709
|
:return: Number of impulse response files
|
737
710
|
"""
|
738
711
|
with self.db() as c:
|
739
|
-
return int(c.execute("SELECT count(
|
712
|
+
return int(c.execute("SELECT count(id) FROM ir_file").fetchone()[0])
|
740
713
|
|
714
|
+
@cached_property
|
715
|
+
def ir_tags(self) -> list[str]:
|
716
|
+
"""Get tags of impulse response files from db
|
717
|
+
|
718
|
+
:return: Tags of impulse response files
|
719
|
+
"""
|
720
|
+
with self.db() as c:
|
721
|
+
return [tag[0] for tag in c.execute("SELECT tag FROM ir_tag").fetchall()]
|
722
|
+
|
723
|
+
@property
|
741
724
|
def mixtures(self) -> list[Mixture]:
|
742
725
|
"""Get mixtures from db
|
743
726
|
|
744
727
|
:return: Mixtures
|
745
728
|
"""
|
746
729
|
from .db_datatypes import MixtureRecord
|
747
|
-
from .db_datatypes import
|
730
|
+
from .db_datatypes import SourceRecord
|
748
731
|
from .helpers import to_mixture
|
749
|
-
from .helpers import
|
732
|
+
from .helpers import to_source
|
750
733
|
|
751
734
|
with self.db() as c:
|
752
735
|
mixtures: list[Mixture] = []
|
753
736
|
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
754
|
-
|
755
|
-
|
756
|
-
for
|
737
|
+
sources_list = [
|
738
|
+
to_source(SourceRecord(*source))
|
739
|
+
for source in c.execute(
|
757
740
|
"""
|
758
|
-
SELECT
|
759
|
-
FROM
|
760
|
-
WHERE ? =
|
741
|
+
SELECT source.*
|
742
|
+
FROM source, mixture_source
|
743
|
+
WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
|
761
744
|
""",
|
762
745
|
(mixture.id,),
|
763
746
|
).fetchall()
|
764
747
|
]
|
765
|
-
|
748
|
+
|
749
|
+
sources: Sources = {}
|
750
|
+
for source in sources_list:
|
751
|
+
sources[self.source_file(source.file_id).category] = source
|
752
|
+
|
753
|
+
mixtures.append(to_mixture(mixture, sources))
|
754
|
+
|
766
755
|
return mixtures
|
767
756
|
|
768
757
|
@cached_property
|
@@ -772,7 +761,7 @@ class MixtureDatabase:
|
|
772
761
|
:return: List of zero-based mixture IDs
|
773
762
|
"""
|
774
763
|
with self.db() as c:
|
775
|
-
return [int(item[0]) - 1 for item in c.execute("SELECT
|
764
|
+
return [int(item[0]) - 1 for item in c.execute("SELECT id FROM mixture").fetchall()]
|
776
765
|
|
777
766
|
def mixture(self, m_id: int) -> Mixture:
|
778
767
|
"""Get mixture record with ID from db
|
@@ -785,7 +774,7 @@ class MixtureDatabase:
|
|
785
774
|
@cached_property
|
786
775
|
def mixid_width(self) -> int:
|
787
776
|
with self.db() as c:
|
788
|
-
return int(c.execute("SELECT
|
777
|
+
return int(c.execute("SELECT mixid_width FROM top").fetchone()[0])
|
789
778
|
|
790
779
|
def mixture_location(self, m_id: int) -> str:
|
791
780
|
"""Get the file location for the give mixture ID
|
@@ -804,231 +793,342 @@ class MixtureDatabase:
|
|
804
793
|
:return: Number of mixtures
|
805
794
|
"""
|
806
795
|
with self.db() as c:
|
807
|
-
return int(c.execute("SELECT count(
|
796
|
+
return int(c.execute("SELECT count(id) FROM mixture").fetchone()[0])
|
808
797
|
|
809
|
-
def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
|
798
|
+
def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
|
810
799
|
"""Read mixture data
|
811
800
|
|
812
801
|
:param m_id: Zero-based mixture ID
|
813
802
|
:param items: String(s) of dataset(s) to retrieve
|
814
|
-
:return:
|
803
|
+
:return: Dictionary of name: data
|
815
804
|
"""
|
816
|
-
from
|
805
|
+
from .data_io import read_cached_data
|
817
806
|
|
818
807
|
return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
|
819
808
|
|
820
|
-
def
|
821
|
-
"""Read
|
809
|
+
def read_source_audio(self, s_id: int) -> AudioT:
|
810
|
+
"""Read source audio
|
822
811
|
|
823
|
-
:param
|
824
|
-
:return:
|
812
|
+
:param s_id: Source ID
|
813
|
+
:return: Source audio
|
825
814
|
"""
|
826
815
|
from .audio import read_audio
|
827
816
|
|
828
|
-
return read_audio(self.
|
829
|
-
|
830
|
-
def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
|
831
|
-
"""Get augmented noise audio
|
832
|
-
|
833
|
-
:param mixture: Mixture
|
834
|
-
:return: Augmented noise audio
|
835
|
-
"""
|
836
|
-
from .audio import read_audio
|
837
|
-
from .augmentation import apply_augmentation
|
838
|
-
|
839
|
-
noise = self.noise_file(mixture.noise.file_id)
|
840
|
-
audio = read_audio(noise.name, self.use_cache)
|
841
|
-
audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
|
842
|
-
|
843
|
-
return audio
|
817
|
+
return read_audio(self.source_file(s_id).name, self.use_cache)
|
844
818
|
|
845
819
|
def mixture_class_indices(self, m_id: int) -> list[int]:
|
846
820
|
class_indices: list[int] = []
|
847
|
-
for
|
848
|
-
class_indices.extend(self.
|
821
|
+
for s_id in self.mixture(m_id).source_ids.values():
|
822
|
+
class_indices.extend(self.source_file(s_id).class_indices)
|
849
823
|
return sorted(set(class_indices))
|
850
824
|
|
851
|
-
def
|
852
|
-
"""Get the
|
825
|
+
def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
|
826
|
+
"""Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID
|
853
827
|
|
854
828
|
:param m_id: Zero-based mixture ID
|
855
829
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
856
|
-
:
|
830
|
+
:param cache: Cache result
|
831
|
+
:return: Dictionary of pre-truth source audio data (one per source in the mixture)
|
857
832
|
"""
|
858
|
-
from .
|
859
|
-
from .
|
860
|
-
from .
|
833
|
+
from .data_io import write_cached_data
|
834
|
+
from .effects import apply_effects
|
835
|
+
from .effects import conform_audio_to_length
|
861
836
|
|
862
837
|
if not force:
|
863
|
-
|
864
|
-
if
|
865
|
-
return
|
838
|
+
sources = self.read_mixture_data(m_id, "sources")["sources"]
|
839
|
+
if sources is not None:
|
840
|
+
return sources
|
866
841
|
|
867
842
|
mixture = self.mixture(m_id)
|
868
843
|
if mixture is None:
|
869
844
|
raise ValueError(f"Could not find mixture for m_id: {m_id}")
|
870
845
|
|
871
|
-
|
872
|
-
for
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
846
|
+
sources = {}
|
847
|
+
for category, source in mixture.all_sources.items():
|
848
|
+
source = mixture.all_sources[category]
|
849
|
+
audio = self.read_source_audio(source.file_id)
|
850
|
+
audio = apply_effects(self, audio, source.effects, pre=True, post=False)
|
851
|
+
audio = conform_audio_to_length(audio, mixture.samples, source.repeat, source.start)
|
852
|
+
sources[category] = audio
|
853
|
+
|
854
|
+
if cache:
|
855
|
+
write_cached_data(
|
856
|
+
location=self.location,
|
857
|
+
name="mixture",
|
858
|
+
index=mixture.name,
|
859
|
+
items={"sources": sources},
|
879
860
|
)
|
880
|
-
target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
|
881
|
-
target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
|
882
|
-
targets_audio.append(target_audio)
|
883
861
|
|
884
|
-
return
|
862
|
+
return sources
|
885
863
|
|
886
|
-
def
|
887
|
-
|
864
|
+
def mixture_sources_f(
|
865
|
+
self,
|
866
|
+
m_id: int,
|
867
|
+
sources: SourcesAudioT | None = None,
|
868
|
+
force: bool = False,
|
869
|
+
cache: bool = False,
|
870
|
+
) -> SourcesAudioF:
|
871
|
+
"""Get the pre-truth source transform data (one per source in the mixture) for the given mixture ID
|
888
872
|
|
889
873
|
:param m_id: Zero-based mixture ID
|
890
|
-
:param
|
874
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
891
875
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
892
|
-
:
|
876
|
+
:param cache: Cache result
|
877
|
+
:return: Dictionary of pre-truth source transform data (one per source in the mixture)
|
893
878
|
"""
|
879
|
+
from .data_io import write_cached_data
|
894
880
|
from .helpers import forward_transform
|
895
881
|
|
896
|
-
if
|
897
|
-
|
882
|
+
if sources is None:
|
883
|
+
sources = self.mixture_sources(m_id, force)
|
898
884
|
|
899
|
-
|
885
|
+
sources_f = {category: forward_transform(sources[category], self.ft_config) for category in sources}
|
900
886
|
|
901
|
-
|
902
|
-
|
887
|
+
if cache:
|
888
|
+
write_cached_data(
|
889
|
+
location=self.location,
|
890
|
+
name="mixture",
|
891
|
+
index=self.mixture(m_id).name,
|
892
|
+
items={"sources_f": sources_f},
|
893
|
+
)
|
894
|
+
|
895
|
+
return sources_f
|
896
|
+
|
897
|
+
def mixture_source(
|
898
|
+
self,
|
899
|
+
m_id: int,
|
900
|
+
sources: SourcesAudioT | None = None,
|
901
|
+
force: bool = False,
|
902
|
+
cache: bool = False,
|
903
|
+
) -> AudioT:
|
904
|
+
"""Get the post-truth, summed, and gained source audio data for the given mixture ID
|
903
905
|
|
904
906
|
:param m_id: Zero-based mixture ID
|
905
|
-
:param
|
907
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
906
908
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
907
|
-
:
|
909
|
+
:param cache: Cache result
|
910
|
+
:return: Post-truth, gained, and summed source audio data
|
908
911
|
"""
|
909
|
-
|
912
|
+
import numpy as np
|
913
|
+
|
914
|
+
from .data_io import write_cached_data
|
915
|
+
from .effects import apply_effects
|
910
916
|
|
911
917
|
if not force:
|
912
|
-
|
913
|
-
if
|
914
|
-
return
|
918
|
+
source = self.read_mixture_data(m_id, "source")["source"]
|
919
|
+
if source is not None:
|
920
|
+
return source
|
921
|
+
|
922
|
+
if sources is None:
|
923
|
+
sources = self.mixture_sources(m_id, force)
|
924
|
+
|
925
|
+
mixture = self.mixture(m_id)
|
915
926
|
|
916
|
-
|
917
|
-
|
927
|
+
source = np.sum(
|
928
|
+
[
|
929
|
+
apply_effects(
|
930
|
+
self,
|
931
|
+
audio=sources[category],
|
932
|
+
effects=mixture.all_sources[category].effects,
|
933
|
+
pre=False,
|
934
|
+
post=True,
|
935
|
+
)
|
936
|
+
* mixture.all_sources[category].snr_gain
|
937
|
+
for category in sources
|
938
|
+
if category != "noise"
|
939
|
+
],
|
940
|
+
axis=0,
|
941
|
+
)
|
918
942
|
|
919
|
-
|
943
|
+
if cache:
|
944
|
+
write_cached_data(
|
945
|
+
location=self.location,
|
946
|
+
name="mixture",
|
947
|
+
index=mixture.name,
|
948
|
+
items={"source": source},
|
949
|
+
)
|
950
|
+
|
951
|
+
return source
|
920
952
|
|
921
|
-
def
|
953
|
+
def mixture_source_f(
|
922
954
|
self,
|
923
955
|
m_id: int,
|
924
|
-
|
925
|
-
|
956
|
+
sources: SourcesAudioT | None = None,
|
957
|
+
source: AudioT | None = None,
|
926
958
|
force: bool = False,
|
959
|
+
cache: bool = False,
|
927
960
|
) -> AudioF:
|
928
|
-
"""Get the
|
961
|
+
"""Get the post-truth, summed, and gained source transform data for the given mixture ID
|
929
962
|
|
930
963
|
:param m_id: Zero-based mixture ID
|
931
|
-
:param
|
932
|
-
:param
|
964
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
965
|
+
:param source: Post-truth, gained, and summed source audio for the given m_id
|
933
966
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
934
|
-
:
|
967
|
+
:param cache: Cache result
|
968
|
+
:return: Post-truth, gained, and summed source transform data
|
935
969
|
"""
|
970
|
+
from .data_io import write_cached_data
|
936
971
|
from .helpers import forward_transform
|
937
972
|
|
938
|
-
if
|
939
|
-
|
973
|
+
if source is None:
|
974
|
+
source = self.mixture_source(m_id, sources, force)
|
940
975
|
|
941
|
-
|
976
|
+
source_f = forward_transform(source, self.ft_config)
|
942
977
|
|
943
|
-
|
944
|
-
|
978
|
+
if cache:
|
979
|
+
write_cached_data(
|
980
|
+
location=self.location,
|
981
|
+
name="mixture",
|
982
|
+
index=self.mixture(m_id).name,
|
983
|
+
items={"source_f": source_f},
|
984
|
+
)
|
985
|
+
|
986
|
+
return source_f
|
987
|
+
|
988
|
+
def mixture_noise(
|
989
|
+
self,
|
990
|
+
m_id: int,
|
991
|
+
sources: SourcesAudioT | None = None,
|
992
|
+
force: bool = False,
|
993
|
+
cache: bool = False,
|
994
|
+
) -> AudioT:
|
995
|
+
"""Get the post-truth and gained noise audio data for the given mixture ID
|
945
996
|
|
946
997
|
:param m_id: Zero-based mixture ID
|
998
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
947
999
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
948
|
-
:
|
1000
|
+
:param cache: Cache result
|
1001
|
+
:return: Post-truth and gained noise audio data
|
949
1002
|
"""
|
950
|
-
from .
|
951
|
-
from .
|
1003
|
+
from .data_io import write_cached_data
|
1004
|
+
from .effects import apply_effects
|
952
1005
|
|
953
1006
|
if not force:
|
954
|
-
noise = self.read_mixture_data(m_id, "noise")
|
1007
|
+
noise = self.read_mixture_data(m_id, "noise")["noise"]
|
955
1008
|
if noise is not None:
|
956
1009
|
return noise
|
957
1010
|
|
958
|
-
|
959
|
-
|
960
|
-
noise = get_next_noise(audio=noise, offset=mixture.noise_offset, length=mixture.samples)
|
961
|
-
return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
|
1011
|
+
if sources is None:
|
1012
|
+
sources = self.mixture_sources(m_id, force)
|
962
1013
|
|
963
|
-
|
964
|
-
|
1014
|
+
noise = self.mixture(m_id).noise
|
1015
|
+
noise = apply_effects(self, sources["noise"], noise.effects, pre=False, post=True) * noise.snr_gain
|
1016
|
+
|
1017
|
+
if cache:
|
1018
|
+
write_cached_data(
|
1019
|
+
location=self.location,
|
1020
|
+
name="mixture",
|
1021
|
+
index=self.mixture(m_id).name,
|
1022
|
+
items={"noise": noise},
|
1023
|
+
)
|
1024
|
+
|
1025
|
+
return noise
|
1026
|
+
|
1027
|
+
def mixture_noise_f(
|
1028
|
+
self,
|
1029
|
+
m_id: int,
|
1030
|
+
sources: SourcesAudioT | None = None,
|
1031
|
+
noise: AudioT | None = None,
|
1032
|
+
force: bool = False,
|
1033
|
+
cache: bool = False,
|
1034
|
+
) -> AudioF:
|
1035
|
+
"""Get the post-truth and gained noise transform for the given mixture ID
|
965
1036
|
|
966
1037
|
:param m_id: Zero-based mixture ID
|
967
|
-
:param
|
1038
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1039
|
+
:param noise: Post-truth and gained noise audio data
|
968
1040
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
969
|
-
:
|
1041
|
+
:param cache: Cache result
|
1042
|
+
:return: Post-truth and gained noise transform data
|
970
1043
|
"""
|
1044
|
+
from .data_io import write_cached_data
|
971
1045
|
from .helpers import forward_transform
|
972
1046
|
|
973
1047
|
if force or noise is None:
|
974
|
-
noise = self.mixture_noise(m_id, force)
|
1048
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1049
|
+
|
1050
|
+
noise_f = forward_transform(noise, self.ft_config)
|
1051
|
+
if cache:
|
1052
|
+
write_cached_data(
|
1053
|
+
location=self.location,
|
1054
|
+
name="mixture",
|
1055
|
+
index=self.mixture(m_id).name,
|
1056
|
+
items={"noise_f": noise_f},
|
1057
|
+
)
|
975
1058
|
|
976
|
-
return
|
1059
|
+
return noise_f
|
977
1060
|
|
978
1061
|
def mixture_mixture(
|
979
1062
|
self,
|
980
1063
|
m_id: int,
|
981
|
-
|
982
|
-
|
1064
|
+
sources: SourcesAudioT | None = None,
|
1065
|
+
source: AudioT | None = None,
|
983
1066
|
noise: AudioT | None = None,
|
984
1067
|
force: bool = False,
|
1068
|
+
cache: bool = False,
|
985
1069
|
) -> AudioT:
|
986
1070
|
"""Get the mixture audio data for the given mixture ID
|
987
1071
|
|
988
1072
|
:param m_id: Zero-based mixture ID
|
989
|
-
:param
|
990
|
-
:param
|
991
|
-
:param noise:
|
1073
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1074
|
+
:param source: Post-truth, gained, and summed source audio data
|
1075
|
+
:param noise: Post-truth and gained noise audio data
|
992
1076
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1077
|
+
:param cache: Cache result
|
993
1078
|
:return: Mixture audio data
|
994
1079
|
"""
|
1080
|
+
from .data_io import write_cached_data
|
1081
|
+
|
995
1082
|
if not force:
|
996
|
-
mixture = self.read_mixture_data(m_id, "mixture")
|
1083
|
+
mixture = self.read_mixture_data(m_id, "mixture")["mixture"]
|
997
1084
|
if mixture is not None:
|
998
1085
|
return mixture
|
999
1086
|
|
1000
|
-
if
|
1001
|
-
|
1087
|
+
if source is None:
|
1088
|
+
source = self.mixture_source(m_id, sources, force)
|
1002
1089
|
|
1003
|
-
if
|
1004
|
-
noise = self.mixture_noise(m_id, force)
|
1090
|
+
if noise is None:
|
1091
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1092
|
+
|
1093
|
+
mixture = source + noise
|
1005
1094
|
|
1006
|
-
|
1095
|
+
if cache:
|
1096
|
+
write_cached_data(
|
1097
|
+
location=self.location,
|
1098
|
+
name="mixture",
|
1099
|
+
index=self.mixture(m_id).name,
|
1100
|
+
items={"mixture": mixture},
|
1101
|
+
)
|
1102
|
+
|
1103
|
+
return mixture
|
1007
1104
|
|
1008
1105
|
def mixture_mixture_f(
|
1009
1106
|
self,
|
1010
1107
|
m_id: int,
|
1011
|
-
|
1012
|
-
|
1108
|
+
sources: SourcesAudioT | None = None,
|
1109
|
+
source: AudioT | None = None,
|
1013
1110
|
noise: AudioT | None = None,
|
1014
1111
|
mixture: AudioT | None = None,
|
1015
1112
|
force: bool = False,
|
1113
|
+
cache: bool = False,
|
1016
1114
|
) -> AudioF:
|
1017
1115
|
"""Get the mixture transform for the given mixture ID
|
1018
1116
|
|
1019
1117
|
:param m_id: Zero-based mixture ID
|
1020
|
-
:param
|
1021
|
-
:param
|
1022
|
-
:param noise:
|
1118
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1119
|
+
:param source: Post-truth, gained, and summed source audio data
|
1120
|
+
:param noise: Post-truth and gained noise audio data
|
1023
1121
|
:param mixture: Mixture audio data
|
1024
1122
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1123
|
+
:param cache: Cache result
|
1025
1124
|
:return: Mixture transform data
|
1026
1125
|
"""
|
1126
|
+
from .data_io import write_cached_data
|
1027
1127
|
from .helpers import forward_transform
|
1028
1128
|
from .spectral_mask import apply_spectral_mask
|
1029
1129
|
|
1030
|
-
if
|
1031
|
-
mixture = self.mixture_mixture(m_id,
|
1130
|
+
if mixture is None:
|
1131
|
+
mixture = self.mixture_mixture(m_id, sources, source, noise, force)
|
1032
1132
|
|
1033
1133
|
mixture_f = forward_transform(mixture, self.ft_config)
|
1034
1134
|
|
@@ -1040,80 +1140,79 @@ class MixtureDatabase:
|
|
1040
1140
|
seed=m.spectral_mask_seed,
|
1041
1141
|
)
|
1042
1142
|
|
1143
|
+
if cache:
|
1144
|
+
write_cached_data(
|
1145
|
+
location=self.location,
|
1146
|
+
name="mixture",
|
1147
|
+
index=self.mixture(m_id).name,
|
1148
|
+
items={"mixture_f": mixture_f},
|
1149
|
+
)
|
1150
|
+
|
1043
1151
|
return mixture_f
|
1044
1152
|
|
1045
|
-
def mixture_truth_t(
|
1046
|
-
self,
|
1047
|
-
m_id: int,
|
1048
|
-
targets: list[AudioT] | None = None,
|
1049
|
-
noise: AudioT | None = None,
|
1050
|
-
mixture: AudioT | None = None,
|
1051
|
-
force: bool = False,
|
1052
|
-
) -> list[TruthDict]:
|
1153
|
+
def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
|
1053
1154
|
"""Get the truth_t data for the given mixture ID
|
1054
1155
|
|
1055
1156
|
:param m_id: Zero-based mixture ID
|
1056
|
-
:param targets: List of augmented target audio data (one per target in the mixup) for the given mixture ID
|
1057
|
-
:param noise: Augmented noise audio data for the given mixture ID
|
1058
|
-
:param mixture: Mixture audio data for the given mixture ID
|
1059
1157
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1158
|
+
:param cache: Cache result
|
1060
1159
|
:return: list of truth_t data
|
1061
1160
|
"""
|
1161
|
+
from .data_io import write_cached_data
|
1062
1162
|
from .truth import truth_function
|
1063
1163
|
|
1064
1164
|
if not force:
|
1065
|
-
truth_t = self.read_mixture_data(m_id, "truth_t")
|
1165
|
+
truth_t = self.read_mixture_data(m_id, "truth_t")["truth_t"]
|
1066
1166
|
if truth_t is not None:
|
1067
1167
|
return truth_t
|
1068
1168
|
|
1069
|
-
|
1070
|
-
targets = self.mixture_targets(m_id, force)
|
1071
|
-
|
1072
|
-
if force or noise is None:
|
1073
|
-
noise = self.mixture_noise(m_id, force)
|
1074
|
-
|
1075
|
-
if force or mixture is None:
|
1076
|
-
mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
|
1077
|
-
|
1078
|
-
if not all(len(target) == self.mixture(m_id).samples for target in targets):
|
1079
|
-
raise ValueError("Lengths of targets do not match length of mixture")
|
1169
|
+
truth_t = truth_function(self, m_id)
|
1080
1170
|
|
1081
|
-
if
|
1082
|
-
|
1171
|
+
if cache:
|
1172
|
+
write_cached_data(
|
1173
|
+
location=self.location,
|
1174
|
+
name="mixture",
|
1175
|
+
index=self.mixture(m_id).name,
|
1176
|
+
items={"truth_t": truth_t},
|
1177
|
+
)
|
1083
1178
|
|
1084
|
-
return
|
1179
|
+
return truth_t
|
1085
1180
|
|
1086
1181
|
def mixture_segsnr_t(
|
1087
1182
|
self,
|
1088
1183
|
m_id: int,
|
1089
|
-
|
1090
|
-
|
1184
|
+
sources: SourcesAudioT | None = None,
|
1185
|
+
source: AudioT | None = None,
|
1091
1186
|
noise: AudioT | None = None,
|
1092
1187
|
force: bool = False,
|
1188
|
+
cache: bool = False,
|
1093
1189
|
) -> Segsnr:
|
1094
1190
|
"""Get the segsnr_t data for the given mixture ID
|
1095
1191
|
|
1096
1192
|
:param m_id: Zero-based mixture ID
|
1097
|
-
:param
|
1098
|
-
:param
|
1099
|
-
:param noise:
|
1193
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1194
|
+
:param source: Post-truth, gained, and summed source audio data
|
1195
|
+
:param noise: Post-truth and gained noise audio data
|
1100
1196
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1197
|
+
:param cache: Cache result
|
1101
1198
|
:return: segsnr_t data
|
1102
1199
|
"""
|
1103
1200
|
import numpy as np
|
1104
1201
|
import torch
|
1105
1202
|
from pyaaware import ForwardTransform
|
1106
1203
|
|
1204
|
+
from .data_io import write_cached_data
|
1205
|
+
|
1107
1206
|
if not force:
|
1108
|
-
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
1207
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
|
1109
1208
|
if segsnr_t is not None:
|
1110
1209
|
return segsnr_t
|
1111
1210
|
|
1112
|
-
if
|
1113
|
-
|
1211
|
+
if source is None:
|
1212
|
+
source = self.mixture_source(m_id, sources, force)
|
1114
1213
|
|
1115
|
-
if
|
1116
|
-
noise = self.mixture_noise(m_id, force)
|
1214
|
+
if noise is None:
|
1215
|
+
noise = self.mixture_noise(m_id, sources, force)
|
1117
1216
|
|
1118
1217
|
ft = ForwardTransform(
|
1119
1218
|
length=self.ft_config.length,
|
@@ -1127,13 +1226,13 @@ class MixtureDatabase:
|
|
1127
1226
|
|
1128
1227
|
segsnr_t = np.empty(mixture.samples, dtype=np.float32)
|
1129
1228
|
|
1130
|
-
|
1229
|
+
source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
|
1131
1230
|
noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
|
1132
1231
|
|
1133
1232
|
offsets = range(0, mixture.samples, self.ft_config.overlap)
|
1134
|
-
if len(
|
1233
|
+
if len(source_energy) != len(offsets):
|
1135
1234
|
raise ValueError(
|
1136
|
-
f"Number of frames in energy, {len(
|
1235
|
+
f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(offsets)}"
|
1137
1236
|
)
|
1138
1237
|
|
1139
1238
|
for idx, offset in enumerate(offsets):
|
@@ -1142,187 +1241,242 @@ class MixtureDatabase:
|
|
1142
1241
|
if noise_energy[idx] == 0:
|
1143
1242
|
snr = np.float32(np.inf)
|
1144
1243
|
else:
|
1145
|
-
snr = np.float32(
|
1244
|
+
snr = np.float32(source_energy[idx] / noise_energy[idx])
|
1146
1245
|
|
1147
1246
|
segsnr_t[indices] = snr
|
1148
1247
|
|
1248
|
+
if cache:
|
1249
|
+
write_cached_data(
|
1250
|
+
location=self.location,
|
1251
|
+
name="mixture",
|
1252
|
+
index=mixture.name,
|
1253
|
+
items={"segsnr_t": segsnr_t},
|
1254
|
+
)
|
1255
|
+
|
1149
1256
|
return segsnr_t
|
1150
1257
|
|
1151
1258
|
def mixture_segsnr(
|
1152
1259
|
self,
|
1153
1260
|
m_id: int,
|
1154
1261
|
segsnr_t: Segsnr | None = None,
|
1155
|
-
|
1156
|
-
|
1262
|
+
sources: SourcesAudioT | None = None,
|
1263
|
+
source: AudioT | None = None,
|
1157
1264
|
noise: AudioT | None = None,
|
1158
1265
|
force: bool = False,
|
1266
|
+
cache: bool = False,
|
1159
1267
|
) -> Segsnr:
|
1160
1268
|
"""Get the segsnr data for the given mixture ID
|
1161
1269
|
|
1162
1270
|
:param m_id: Zero-based mixture ID
|
1163
1271
|
:param segsnr_t: segsnr_t data
|
1164
|
-
:param
|
1165
|
-
:param
|
1166
|
-
:param noise:
|
1272
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1273
|
+
:param source: Post-truth, gained, and summed source audio data
|
1274
|
+
:param noise: Post-truth and gained noise audio data
|
1167
1275
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1276
|
+
:param cache: Cache result
|
1168
1277
|
:return: segsnr data
|
1169
1278
|
"""
|
1279
|
+
from .data_io import write_cached_data
|
1280
|
+
|
1170
1281
|
if not force:
|
1171
|
-
segsnr = self.read_mixture_data(m_id, "segsnr")
|
1282
|
+
segsnr = self.read_mixture_data(m_id, "segsnr")["segsnr"]
|
1172
1283
|
if segsnr is not None:
|
1173
1284
|
return segsnr
|
1174
1285
|
|
1175
|
-
|
1176
|
-
|
1177
|
-
return segsnr_t[0 :: self.ft_config.overlap]
|
1286
|
+
if segsnr_t is None:
|
1287
|
+
segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)
|
1178
1288
|
|
1179
|
-
|
1180
|
-
segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
|
1289
|
+
segsnr = segsnr_t[0 :: self.ft_config.overlap]
|
1181
1290
|
|
1182
|
-
|
1291
|
+
if cache:
|
1292
|
+
write_cached_data(
|
1293
|
+
location=self.location,
|
1294
|
+
name="mixture",
|
1295
|
+
index=self.mixture(m_id).name,
|
1296
|
+
items={"segsnr": segsnr},
|
1297
|
+
)
|
1298
|
+
|
1299
|
+
return segsnr
|
1183
1300
|
|
1184
1301
|
def mixture_ft(
|
1185
1302
|
self,
|
1186
1303
|
m_id: int,
|
1187
|
-
|
1188
|
-
|
1304
|
+
sources: SourcesAudioT | None = None,
|
1305
|
+
source: AudioT | None = None,
|
1189
1306
|
noise: AudioT | None = None,
|
1190
1307
|
mixture_f: AudioF | None = None,
|
1191
1308
|
mixture: AudioT | None = None,
|
1192
|
-
truth_t:
|
1309
|
+
truth_t: TruthsDict | None = None,
|
1193
1310
|
force: bool = False,
|
1194
|
-
|
1311
|
+
cache: bool = False,
|
1312
|
+
) -> tuple[Feature, TruthsDict]:
|
1195
1313
|
"""Get the feature and truth_f data for the given mixture ID
|
1196
1314
|
|
1197
1315
|
:param m_id: Zero-based mixture ID
|
1198
|
-
:param
|
1199
|
-
:param
|
1200
|
-
:param noise:
|
1316
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1317
|
+
:param source: Post-truth, gained, and summed source audio data
|
1318
|
+
:param noise: Post-truth and gained noise audio data
|
1201
1319
|
:param mixture_f: Mixture transform data
|
1202
1320
|
:param mixture: Mixture audio data
|
1203
1321
|
:param truth_t: truth_t
|
1204
1322
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1323
|
+
:param cache: Cache result
|
1205
1324
|
:return: Tuple of (feature, truth_f) data
|
1206
1325
|
"""
|
1207
1326
|
from pyaaware import FeatureGenerator
|
1208
1327
|
|
1328
|
+
from .data_io import write_cached_data
|
1209
1329
|
from .truth import truth_stride_reduction
|
1210
1330
|
|
1211
1331
|
if not force:
|
1212
|
-
|
1213
|
-
if feature is not None and truth_f is not None:
|
1214
|
-
return feature, truth_f
|
1332
|
+
ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
|
1333
|
+
if ft["feature"] is not None and ft["truth_f"] is not None:
|
1334
|
+
return ft["feature"], ft["truth_f"]
|
1215
1335
|
|
1216
|
-
if
|
1336
|
+
if mixture_f is None:
|
1217
1337
|
mixture_f = self.mixture_mixture_f(
|
1218
1338
|
m_id=m_id,
|
1219
|
-
|
1220
|
-
|
1339
|
+
sources=sources,
|
1340
|
+
source=source,
|
1221
1341
|
noise=noise,
|
1222
1342
|
mixture=mixture,
|
1223
1343
|
force=force,
|
1224
1344
|
)
|
1225
1345
|
|
1226
|
-
if
|
1227
|
-
truth_t = self.mixture_truth_t(m_id
|
1346
|
+
if truth_t is None:
|
1347
|
+
truth_t = self.mixture_truth_t(m_id, force)
|
1228
1348
|
|
1229
1349
|
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
1230
1350
|
|
1231
|
-
|
1232
|
-
feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
|
1351
|
+
feature, truth_f = fg.execute_all(mixture_f, truth_t)
|
1233
1352
|
if truth_f is not None:
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1353
|
+
truth_configs = self.mixture_truth_configs(m_id)
|
1354
|
+
for category, configs in truth_configs.items():
|
1355
|
+
for name, config in configs.items():
|
1356
|
+
if self.truth_parameters[category][name] is not None:
|
1357
|
+
truth_f[category][name] = truth_stride_reduction(
|
1358
|
+
truth_f[category][name], config.stride_reduction
|
1359
|
+
)
|
1237
1360
|
else:
|
1238
1361
|
raise TypeError("Unexpected truth of None from feature generator")
|
1239
1362
|
|
1363
|
+
if cache:
|
1364
|
+
write_cached_data(
|
1365
|
+
location=self.location,
|
1366
|
+
name="mixture",
|
1367
|
+
index=self.mixture(m_id).name,
|
1368
|
+
items={"feature": truth_f, "truth_f": truth_f},
|
1369
|
+
)
|
1370
|
+
|
1240
1371
|
return feature, truth_f
|
1241
1372
|
|
1242
1373
|
def mixture_feature(
|
1243
1374
|
self,
|
1244
1375
|
m_id: int,
|
1245
|
-
|
1376
|
+
sources: SourcesAudioT | None = None,
|
1246
1377
|
noise: AudioT | None = None,
|
1247
1378
|
mixture: AudioT | None = None,
|
1248
|
-
truth_t:
|
1379
|
+
truth_t: TruthsDict | None = None,
|
1249
1380
|
force: bool = False,
|
1381
|
+
cache: bool = False,
|
1250
1382
|
) -> Feature:
|
1251
1383
|
"""Get the feature data for the given mixture ID
|
1252
1384
|
|
1253
1385
|
:param m_id: Zero-based mixture ID
|
1254
|
-
:param
|
1255
|
-
:param noise:
|
1386
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1387
|
+
:param noise: Post-truth and gained noise audio data
|
1256
1388
|
:param mixture: Mixture audio data
|
1257
1389
|
:param truth_t: truth_t
|
1258
1390
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1391
|
+
:param cache: Cache result
|
1259
1392
|
:return: Feature data
|
1260
1393
|
"""
|
1261
|
-
|
1394
|
+
from .data_io import write_cached_data
|
1395
|
+
|
1396
|
+
feature = self.mixture_ft(
|
1262
1397
|
m_id=m_id,
|
1263
|
-
|
1398
|
+
sources=sources,
|
1264
1399
|
noise=noise,
|
1265
1400
|
mixture=mixture,
|
1266
1401
|
truth_t=truth_t,
|
1267
1402
|
force=force,
|
1268
|
-
)
|
1403
|
+
)[0]
|
1404
|
+
|
1405
|
+
if cache:
|
1406
|
+
write_cached_data(
|
1407
|
+
location=self.location,
|
1408
|
+
name="mixture",
|
1409
|
+
index=self.mixture(m_id).name,
|
1410
|
+
items={"feature": feature},
|
1411
|
+
)
|
1412
|
+
|
1269
1413
|
return feature
|
1270
1414
|
|
1271
1415
|
def mixture_truth_f(
|
1272
1416
|
self,
|
1273
1417
|
m_id: int,
|
1274
|
-
|
1418
|
+
sources: SourcesAudioT | None = None,
|
1275
1419
|
noise: AudioT | None = None,
|
1276
1420
|
mixture: AudioT | None = None,
|
1277
|
-
truth_t:
|
1421
|
+
truth_t: TruthsDict | None = None,
|
1278
1422
|
force: bool = False,
|
1423
|
+
cache: bool = False,
|
1279
1424
|
) -> TruthDict:
|
1280
1425
|
"""Get the truth_f data for the given mixture ID
|
1281
1426
|
|
1282
1427
|
:param m_id: Zero-based mixture ID
|
1283
|
-
:param
|
1284
|
-
:param noise:
|
1428
|
+
:param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
|
1429
|
+
:param noise: Post-truth and gained noise audio data
|
1285
1430
|
:param mixture: Mixture audio data
|
1286
1431
|
:param truth_t: truth_t
|
1287
1432
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1433
|
+
:param cache: Cache result
|
1288
1434
|
:return: truth_f data
|
1289
1435
|
"""
|
1290
|
-
|
1436
|
+
from .data_io import write_cached_data
|
1437
|
+
|
1438
|
+
truth_f = self.mixture_ft(
|
1291
1439
|
m_id=m_id,
|
1292
|
-
|
1440
|
+
sources=sources,
|
1293
1441
|
noise=noise,
|
1294
1442
|
mixture=mixture,
|
1295
1443
|
truth_t=truth_t,
|
1296
1444
|
force=force,
|
1297
|
-
)
|
1445
|
+
)[1]
|
1446
|
+
|
1447
|
+
if cache:
|
1448
|
+
write_cached_data(
|
1449
|
+
location=self.location,
|
1450
|
+
name="mixture",
|
1451
|
+
index=self.mixture(m_id).name,
|
1452
|
+
items={"truth_f": truth_f},
|
1453
|
+
)
|
1454
|
+
|
1298
1455
|
return truth_f
|
1299
1456
|
|
1300
|
-
def mixture_class_count(
|
1301
|
-
self,
|
1302
|
-
m_id: int,
|
1303
|
-
targets: list[AudioT] | None = None,
|
1304
|
-
noise: AudioT | None = None,
|
1305
|
-
truth_t: list[TruthDict] | None = None,
|
1306
|
-
) -> ClassCount:
|
1457
|
+
def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
|
1307
1458
|
"""Compute the number of frames for which each class index is active for the given mixture ID
|
1308
1459
|
|
1309
1460
|
:param m_id: Zero-based mixture ID
|
1310
|
-
:param targets: List of augmented target audio (one per target in the mixup)
|
1311
|
-
:param noise: Augmented noise audio
|
1312
1461
|
:param truth_t: truth_t
|
1313
|
-
:return:
|
1462
|
+
:return: Dictionary of class counts
|
1314
1463
|
"""
|
1315
1464
|
import numpy as np
|
1316
1465
|
|
1317
1466
|
if truth_t is None:
|
1318
|
-
truth_t = self.mixture_truth_t(m_id
|
1467
|
+
truth_t = self.mixture_truth_t(m_id)
|
1319
1468
|
|
1320
|
-
class_count
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1469
|
+
class_count: dict[str, ClassCount] = {}
|
1470
|
+
|
1471
|
+
truth_configs = self.mixture_truth_configs(m_id)
|
1472
|
+
for category in truth_configs:
|
1473
|
+
class_count[category] = [0] * self.num_classes
|
1474
|
+
for configs in truth_configs[category]:
|
1475
|
+
if "sed" in configs:
|
1476
|
+
for cl in range(self.num_classes):
|
1477
|
+
class_count[category][cl] = int(
|
1478
|
+
np.sum(truth_t[category]["sed"][:, cl] >= self.class_weights_thresholds[cl])
|
1479
|
+
)
|
1326
1480
|
|
1327
1481
|
return class_count
|
1328
1482
|
|
@@ -1348,57 +1502,56 @@ class MixtureDatabase:
|
|
1348
1502
|
return _speaker(self.db, s_id, tier, self.use_cache)
|
1349
1503
|
|
1350
1504
|
def speech_metadata(self, tier: str) -> list[str]:
|
1351
|
-
from .helpers import
|
1505
|
+
from .helpers import get_textgrid_tier_from_source_file
|
1352
1506
|
|
1353
1507
|
results: set[str] = set()
|
1354
1508
|
if tier in self.textgrid_metadata_tiers:
|
1355
|
-
for
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1509
|
+
for source_files in self.source_files.values():
|
1510
|
+
for source_file in source_files:
|
1511
|
+
data = get_textgrid_tier_from_source_file(source_file.name, tier)
|
1512
|
+
if data is None:
|
1513
|
+
continue
|
1514
|
+
if isinstance(data, list):
|
1515
|
+
for item in data:
|
1516
|
+
results.add(item.label)
|
1517
|
+
else:
|
1518
|
+
results.add(data)
|
1364
1519
|
elif tier in self.speaker_metadata_tiers:
|
1365
|
-
for
|
1366
|
-
|
1367
|
-
|
1368
|
-
|
1520
|
+
for source_files in self.source_files.values():
|
1521
|
+
for source_file in source_files:
|
1522
|
+
data = self.speaker(source_file.speaker_id, tier)
|
1523
|
+
if data is not None:
|
1524
|
+
results.add(data)
|
1369
1525
|
|
1370
1526
|
return sorted(results)
|
1371
1527
|
|
1372
|
-
def mixture_speech_metadata(self, mixid: int, tier: str) ->
|
1528
|
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
|
1373
1529
|
from praatio.utilities.constants import Interval
|
1374
1530
|
|
1375
|
-
from .helpers import
|
1531
|
+
from .helpers import get_textgrid_tier_from_source_file
|
1376
1532
|
|
1377
|
-
results:
|
1533
|
+
results: dict[str, SpeechMetadata] = {}
|
1378
1534
|
is_textgrid = tier in self.textgrid_metadata_tiers
|
1379
1535
|
if is_textgrid:
|
1380
|
-
for
|
1381
|
-
data =
|
1536
|
+
for category, source in self.mixture(mixid).all_sources.items():
|
1537
|
+
data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
|
1382
1538
|
if isinstance(data, list):
|
1383
|
-
# Check for tempo
|
1539
|
+
# Check for tempo effect and adjust Interval start and end data as needed
|
1384
1540
|
entries = []
|
1385
1541
|
for entry in data:
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
entry.label,
|
1392
|
-
)
|
1542
|
+
entries.append(
|
1543
|
+
Interval(
|
1544
|
+
entry.start / source.pre_tempo,
|
1545
|
+
entry.end / source.pre_tempo,
|
1546
|
+
entry.label,
|
1393
1547
|
)
|
1394
|
-
|
1395
|
-
|
1396
|
-
results.append(entries)
|
1548
|
+
)
|
1549
|
+
results[category] = entries
|
1397
1550
|
else:
|
1398
|
-
results
|
1551
|
+
results[category] = data
|
1399
1552
|
else:
|
1400
|
-
for
|
1401
|
-
results
|
1553
|
+
for category, source in self.mixture(mixid).all_sources.items():
|
1554
|
+
results[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
|
1402
1555
|
|
1403
1556
|
return results
|
1404
1557
|
|
@@ -1407,7 +1560,7 @@ class MixtureDatabase:
|
|
1407
1560
|
tier: str | None = None,
|
1408
1561
|
value: str | None = None,
|
1409
1562
|
where: str | None = None,
|
1410
|
-
) -> list[int]:
|
1563
|
+
) -> dict[str, list[int]]:
|
1411
1564
|
"""Get a list of mixture IDs for the given speech metadata tier.
|
1412
1565
|
|
1413
1566
|
If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
@@ -1441,16 +1594,29 @@ class MixtureDatabase:
|
|
1441
1594
|
results = c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
|
1442
1595
|
speaker_ids = ",".join(map(str, [i[0] for i in results]))
|
1443
1596
|
|
1444
|
-
results = c.execute(f"SELECT id FROM
|
1445
|
-
|
1597
|
+
results = c.execute(f"SELECT id, category FROM source_file WHERE speaker_id IN ({speaker_ids})").fetchall()
|
1598
|
+
source_file_ids: dict[str, list[int]] = {}
|
1599
|
+
for result in results:
|
1600
|
+
source_file_id, category = result
|
1601
|
+
if category not in source_file_ids:
|
1602
|
+
source_file_ids[category] = [source_file_id]
|
1603
|
+
else:
|
1604
|
+
source_file_ids[category].append(source_file_id)
|
1446
1605
|
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1606
|
+
mixids: dict[str, list[int]] = {}
|
1607
|
+
for category in source_file_ids:
|
1608
|
+
id_str = ",".join(map(str, source_file_ids[category]))
|
1609
|
+
results = c.execute(f"SELECT id FROM source WHERE file_id IN ({id_str})").fetchall()
|
1610
|
+
source_ids = ",".join(map(str, [i[0] for i in results]))
|
1450
1611
|
|
1451
|
-
|
1612
|
+
results = c.execute(
|
1613
|
+
f"SELECT mixture_id FROM mixture_source WHERE source_id IN ({source_ids})"
|
1614
|
+
).fetchall()
|
1615
|
+
mixids[category] = [mixture_id[0] - 1 for mixture_id in results]
|
1452
1616
|
|
1453
|
-
|
1617
|
+
return mixids
|
1618
|
+
|
1619
|
+
def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
|
1454
1620
|
from .helpers import mixture_all_speech_metadata
|
1455
1621
|
|
1456
1622
|
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
@@ -1483,63 +1649,65 @@ class MixtureDatabase:
|
|
1483
1649
|
:param m_id: Zero-based mixture ID
|
1484
1650
|
:param metrics: List of metrics to get
|
1485
1651
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1486
|
-
:return:
|
1652
|
+
:return: Dictionary of metric data
|
1487
1653
|
"""
|
1488
1654
|
from collections.abc import Callable
|
1489
1655
|
|
1490
1656
|
import numpy as np
|
1491
1657
|
from pystoi import stoi
|
1492
1658
|
|
1493
|
-
from
|
1494
|
-
from
|
1495
|
-
from
|
1496
|
-
from
|
1497
|
-
from
|
1498
|
-
from
|
1499
|
-
from
|
1500
|
-
from
|
1501
|
-
from
|
1502
|
-
from
|
1503
|
-
from
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1659
|
+
from ..constants import SAMPLE_RATE
|
1660
|
+
from ..datatypes import AudioStatsMetrics
|
1661
|
+
from ..datatypes import SpeechMetrics
|
1662
|
+
from ..metrics.calc_audio_stats import calc_audio_stats
|
1663
|
+
from ..metrics.calc_pesq import calc_pesq
|
1664
|
+
from ..metrics.calc_phase_distance import calc_phase_distance
|
1665
|
+
from ..metrics.calc_segsnr_f import calc_segsnr_f
|
1666
|
+
from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
|
1667
|
+
from ..metrics.calc_speech import calc_speech
|
1668
|
+
from ..metrics.calc_wer import calc_wer
|
1669
|
+
from ..metrics.calc_wsdr import calc_wsdr
|
1670
|
+
from ..utils.asr import calc_asr
|
1671
|
+
from ..utils.db import linear_to_db
|
1672
|
+
|
1673
|
+
def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
|
1674
|
+
state: dict[str, AudioT] | None = None
|
1675
|
+
|
1676
|
+
def get() -> dict[str, AudioT]:
|
1509
1677
|
nonlocal state
|
1510
1678
|
if state is None:
|
1511
|
-
state = self.
|
1679
|
+
state = self.mixture_sources(m_id)
|
1512
1680
|
return state
|
1513
1681
|
|
1514
1682
|
return get
|
1515
1683
|
|
1516
|
-
|
1684
|
+
sources_audio = create_sources_audio()
|
1517
1685
|
|
1518
|
-
def
|
1686
|
+
def create_source_audio() -> Callable[[], AudioT]:
|
1519
1687
|
state: AudioT | None = None
|
1520
1688
|
|
1521
1689
|
def get() -> AudioT:
|
1522
1690
|
nonlocal state
|
1523
1691
|
if state is None:
|
1524
|
-
state = self.
|
1692
|
+
state = self.mixture_source(m_id)
|
1525
1693
|
return state
|
1526
1694
|
|
1527
1695
|
return get
|
1528
1696
|
|
1529
|
-
|
1697
|
+
source_audio = create_source_audio()
|
1530
1698
|
|
1531
|
-
def
|
1699
|
+
def create_source_f() -> Callable[[], AudioF]:
|
1532
1700
|
state: AudioF | None = None
|
1533
1701
|
|
1534
1702
|
def get() -> AudioF:
|
1535
1703
|
nonlocal state
|
1536
1704
|
if state is None:
|
1537
|
-
state = self.
|
1705
|
+
state = self.mixture_source_f(m_id)
|
1538
1706
|
return state
|
1539
1707
|
|
1540
1708
|
return get
|
1541
1709
|
|
1542
|
-
|
1710
|
+
source_f = create_source_f()
|
1543
1711
|
|
1544
1712
|
def create_noise_audio() -> Callable[[], AudioT]:
|
1545
1713
|
state: AudioT | None = None
|
@@ -1593,15 +1761,29 @@ class MixtureDatabase:
|
|
1593
1761
|
|
1594
1762
|
segsnr_f = create_segsnr_f()
|
1595
1763
|
|
1596
|
-
def
|
1597
|
-
state:
|
1764
|
+
def create_pesq() -> Callable[[], dict[str, float]]:
|
1765
|
+
state: dict[str, float] | None = None
|
1598
1766
|
|
1599
|
-
def get() ->
|
1767
|
+
def get() -> dict[str, float]:
|
1600
1768
|
nonlocal state
|
1601
1769
|
if state is None:
|
1602
|
-
state =
|
1603
|
-
|
1604
|
-
|
1770
|
+
state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
|
1771
|
+
return state
|
1772
|
+
|
1773
|
+
return get
|
1774
|
+
|
1775
|
+
pesq = create_pesq()
|
1776
|
+
|
1777
|
+
def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
|
1778
|
+
state: dict[str, SpeechMetrics] | None = None
|
1779
|
+
|
1780
|
+
def get() -> dict[str, SpeechMetrics]:
|
1781
|
+
nonlocal state
|
1782
|
+
if state is None:
|
1783
|
+
state = {
|
1784
|
+
category: calc_speech(mixture_audio(), audio, pesq()[category])
|
1785
|
+
for category, audio in sources_audio().items()
|
1786
|
+
}
|
1605
1787
|
return state
|
1606
1788
|
|
1607
1789
|
return get
|
@@ -1621,33 +1803,34 @@ class MixtureDatabase:
|
|
1621
1803
|
|
1622
1804
|
mixture_stats = create_mixture_stats()
|
1623
1805
|
|
1624
|
-
def
|
1625
|
-
state:
|
1806
|
+
def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
|
1807
|
+
state: dict[str, AudioStatsMetrics] | None = None
|
1626
1808
|
|
1627
|
-
def get() ->
|
1809
|
+
def get() -> dict[str, AudioStatsMetrics]:
|
1628
1810
|
nonlocal state
|
1629
1811
|
if state is None:
|
1630
|
-
state =
|
1631
|
-
|
1632
|
-
|
1812
|
+
state = {
|
1813
|
+
category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
|
1814
|
+
for category, audio in sources_audio().items()
|
1815
|
+
}
|
1633
1816
|
return state
|
1634
1817
|
|
1635
1818
|
return get
|
1636
1819
|
|
1637
|
-
|
1820
|
+
sources_stats = create_sources_stats()
|
1638
1821
|
|
1639
|
-
def
|
1822
|
+
def create_source_stats() -> Callable[[], AudioStatsMetrics]:
|
1640
1823
|
state: AudioStatsMetrics | None = None
|
1641
1824
|
|
1642
1825
|
def get() -> AudioStatsMetrics:
|
1643
1826
|
nonlocal state
|
1644
1827
|
if state is None:
|
1645
|
-
state = calc_audio_stats(
|
1828
|
+
state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1646
1829
|
return state
|
1647
1830
|
|
1648
1831
|
return get
|
1649
1832
|
|
1650
|
-
|
1833
|
+
source_stats = create_source_stats()
|
1651
1834
|
|
1652
1835
|
def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
|
1653
1836
|
state: AudioStatsMetrics | None = None
|
@@ -1678,33 +1861,34 @@ class MixtureDatabase:
|
|
1678
1861
|
|
1679
1862
|
asr_config = create_asr_config()
|
1680
1863
|
|
1681
|
-
def
|
1682
|
-
state: dict[str,
|
1864
|
+
def create_sources_asr() -> Callable[[str], dict[str, str]]:
|
1865
|
+
state: dict[str, dict[str, str]] = {}
|
1683
1866
|
|
1684
|
-
def get(asr_name) ->
|
1867
|
+
def get(asr_name) -> dict[str, str]:
|
1685
1868
|
nonlocal state
|
1686
1869
|
if asr_name not in state:
|
1687
|
-
state[asr_name] =
|
1688
|
-
|
1689
|
-
|
1870
|
+
state[asr_name] = {
|
1871
|
+
category: calc_asr(audio, **asr_config(asr_name)).text
|
1872
|
+
for category, audio in sources_audio().items()
|
1873
|
+
}
|
1690
1874
|
return state[asr_name]
|
1691
1875
|
|
1692
1876
|
return get
|
1693
1877
|
|
1694
|
-
|
1878
|
+
sources_asr = create_sources_asr()
|
1695
1879
|
|
1696
|
-
def
|
1880
|
+
def create_source_asr() -> Callable[[str], str]:
|
1697
1881
|
state: dict[str, str] = {}
|
1698
1882
|
|
1699
1883
|
def get(asr_name) -> str:
|
1700
1884
|
nonlocal state
|
1701
1885
|
if asr_name not in state:
|
1702
|
-
state[asr_name] = calc_asr(
|
1886
|
+
state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
|
1703
1887
|
return state[asr_name]
|
1704
1888
|
|
1705
1889
|
return get
|
1706
1890
|
|
1707
|
-
|
1891
|
+
source_asr = create_source_asr()
|
1708
1892
|
|
1709
1893
|
def create_mixture_asr() -> Callable[[str], str]:
|
1710
1894
|
state: dict[str, str] = {}
|
@@ -1728,11 +1912,11 @@ class MixtureDatabase:
|
|
1728
1912
|
|
1729
1913
|
def calc(m: str) -> Any:
|
1730
1914
|
if m == "mxsnr":
|
1731
|
-
return self.mixture(m_id).
|
1915
|
+
return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
|
1732
1916
|
|
1733
1917
|
# Get cached data first, if exists
|
1734
1918
|
if not force:
|
1735
|
-
value = self.read_mixture_data(m_id, m)
|
1919
|
+
value = self.read_mixture_data(m_id, m)[m]
|
1736
1920
|
if value is not None:
|
1737
1921
|
return value
|
1738
1922
|
|
@@ -1744,8 +1928,8 @@ class MixtureDatabase:
|
|
1744
1928
|
# noise only, ignore/reset target asr
|
1745
1929
|
return float("nan")
|
1746
1930
|
|
1747
|
-
if
|
1748
|
-
return calc_wer(mixture_asr(asr_name),
|
1931
|
+
if source_asr(asr_name):
|
1932
|
+
return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
|
1749
1933
|
|
1750
1934
|
# TODO: should this be NaN like above?
|
1751
1935
|
return float(0)
|
@@ -1753,12 +1937,14 @@ class MixtureDatabase:
|
|
1753
1937
|
if m.startswith("basewer"):
|
1754
1938
|
asr_name = get_asr_name(m)
|
1755
1939
|
|
1756
|
-
text = self.mixture_speech_metadata(m_id, "text")
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1940
|
+
text = self.mixture_speech_metadata(m_id, "text")
|
1941
|
+
base_wer: dict[str, float] = {}
|
1942
|
+
for category, source in sources_asr(asr_name).items():
|
1943
|
+
if isinstance(text[category], str):
|
1944
|
+
base_wer[category] = calc_wer(source, str(text[category])).wer * 100
|
1945
|
+
else:
|
1946
|
+
base_wer[category] = 0
|
1947
|
+
return base_wer
|
1762
1948
|
|
1763
1949
|
if m.startswith("mxasr"):
|
1764
1950
|
return mixture_asr(get_asr_name(m))
|
@@ -1769,6 +1955,18 @@ class MixtureDatabase:
|
|
1769
1955
|
if m == "mxssnr_std":
|
1770
1956
|
return calc_segsnr_f(segsnr_f()).std
|
1771
1957
|
|
1958
|
+
if m == "mxssnr_avg_db":
|
1959
|
+
val = calc_segsnr_f(segsnr_f()).avg
|
1960
|
+
if val is not None:
|
1961
|
+
return linear_to_db(val)
|
1962
|
+
return None
|
1963
|
+
|
1964
|
+
if m == "mxssnr_std_db":
|
1965
|
+
val = calc_segsnr_f(segsnr_f()).std
|
1966
|
+
if val is not None:
|
1967
|
+
return linear_to_db(val)
|
1968
|
+
return None
|
1969
|
+
|
1772
1970
|
if m == "mxssnrdb_avg":
|
1773
1971
|
return calc_segsnr_f(segsnr_f()).db_avg
|
1774
1972
|
|
@@ -1776,40 +1974,40 @@ class MixtureDatabase:
|
|
1776
1974
|
return calc_segsnr_f(segsnr_f()).db_std
|
1777
1975
|
|
1778
1976
|
if m == "mxssnrf_avg":
|
1779
|
-
return calc_segsnr_f_bin(
|
1977
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).avg
|
1780
1978
|
|
1781
1979
|
if m == "mxssnrf_std":
|
1782
|
-
return calc_segsnr_f_bin(
|
1980
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).std
|
1783
1981
|
|
1784
1982
|
if m == "mxssnrdbf_avg":
|
1785
|
-
return calc_segsnr_f_bin(
|
1983
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
|
1786
1984
|
|
1787
1985
|
if m == "mxssnrdbf_std":
|
1788
|
-
return calc_segsnr_f_bin(
|
1986
|
+
return calc_segsnr_f_bin(source_f(), noise_f()).db_std
|
1789
1987
|
|
1790
1988
|
if m == "mxpesq":
|
1791
1989
|
if self.mixture(m_id).is_noise_only:
|
1792
|
-
return
|
1793
|
-
return
|
1990
|
+
return dict.fromkeys(pesq(), 0)
|
1991
|
+
return pesq()
|
1794
1992
|
|
1795
1993
|
if m == "mxcsig":
|
1796
1994
|
if self.mixture(m_id).is_noise_only:
|
1797
|
-
return
|
1798
|
-
return
|
1995
|
+
return dict.fromkeys(speech(), 0)
|
1996
|
+
return {category: s.csig for category, s in speech().items()}
|
1799
1997
|
|
1800
1998
|
if m == "mxcbak":
|
1801
1999
|
if self.mixture(m_id).is_noise_only:
|
1802
|
-
return
|
1803
|
-
return
|
2000
|
+
return dict.fromkeys(speech(), 0)
|
2001
|
+
return {category: s.cbak for category, s in speech().items()}
|
1804
2002
|
|
1805
2003
|
if m == "mxcovl":
|
1806
2004
|
if self.mixture(m_id).is_noise_only:
|
1807
|
-
return
|
1808
|
-
return
|
2005
|
+
return dict.fromkeys(speech(), 0)
|
2006
|
+
return {category: s.covl for category, s in speech().items()}
|
1809
2007
|
|
1810
2008
|
if m == "mxwsdr":
|
1811
2009
|
mixture = mixture_audio()[:, np.newaxis]
|
1812
|
-
target =
|
2010
|
+
target = source_audio()[:, np.newaxis]
|
1813
2011
|
noise = noise_audio()[:, np.newaxis]
|
1814
2012
|
return calc_wsdr(
|
1815
2013
|
hypothesis=np.concatenate((mixture, noise), axis=1),
|
@@ -1819,11 +2017,11 @@ class MixtureDatabase:
|
|
1819
2017
|
|
1820
2018
|
if m == "mxpd":
|
1821
2019
|
mixture_f = self.mixture_mixture_f(m_id)
|
1822
|
-
return calc_phase_distance(hypothesis=mixture_f, reference=
|
2020
|
+
return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
|
1823
2021
|
|
1824
2022
|
if m == "mxstoi":
|
1825
2023
|
return stoi(
|
1826
|
-
x=
|
2024
|
+
x=source_audio(),
|
1827
2025
|
y=mixture_audio(),
|
1828
2026
|
fs_sig=SAMPLE_RATE,
|
1829
2027
|
extended=False,
|
@@ -1860,70 +2058,70 @@ class MixtureDatabase:
|
|
1860
2058
|
return mixture_stats().pkc
|
1861
2059
|
|
1862
2060
|
if m == "mxtdco":
|
1863
|
-
return
|
2061
|
+
return source_stats().dco
|
1864
2062
|
|
1865
2063
|
if m == "mxtmin":
|
1866
|
-
return
|
2064
|
+
return source_stats().min
|
1867
2065
|
|
1868
2066
|
if m == "mxtmax":
|
1869
|
-
return
|
2067
|
+
return source_stats().max
|
1870
2068
|
|
1871
2069
|
if m == "mxtpkdb":
|
1872
|
-
return
|
2070
|
+
return source_stats().pkdb
|
1873
2071
|
|
1874
2072
|
if m == "mxtlrms":
|
1875
|
-
return
|
2073
|
+
return source_stats().lrms
|
1876
2074
|
|
1877
2075
|
if m == "mxtpkr":
|
1878
|
-
return
|
2076
|
+
return source_stats().pkr
|
1879
2077
|
|
1880
2078
|
if m == "mxttr":
|
1881
|
-
return
|
2079
|
+
return source_stats().tr
|
1882
2080
|
|
1883
2081
|
if m == "mxtcr":
|
1884
|
-
return
|
2082
|
+
return source_stats().cr
|
1885
2083
|
|
1886
2084
|
if m == "mxtfl":
|
1887
|
-
return
|
2085
|
+
return source_stats().fl
|
1888
2086
|
|
1889
2087
|
if m == "mxtpkc":
|
1890
|
-
return
|
2088
|
+
return source_stats().pkc
|
1891
2089
|
|
1892
|
-
if m == "
|
1893
|
-
return
|
2090
|
+
if m == "sdco":
|
2091
|
+
return {category: s.dco for category, s in sources_stats().items()}
|
1894
2092
|
|
1895
|
-
if m == "
|
1896
|
-
return
|
2093
|
+
if m == "smin":
|
2094
|
+
return {category: s.min for category, s in sources_stats().items()}
|
1897
2095
|
|
1898
|
-
if m == "
|
1899
|
-
return
|
2096
|
+
if m == "smax":
|
2097
|
+
return {category: s.max for category, s in sources_stats().items()}
|
1900
2098
|
|
1901
|
-
if m == "
|
1902
|
-
return
|
2099
|
+
if m == "spkdb":
|
2100
|
+
return {category: s.pkdb for category, s in sources_stats().items()}
|
1903
2101
|
|
1904
|
-
if m == "
|
1905
|
-
return
|
2102
|
+
if m == "slrms":
|
2103
|
+
return {category: s.lrms for category, s in sources_stats().items()}
|
1906
2104
|
|
1907
|
-
if m == "
|
1908
|
-
return
|
2105
|
+
if m == "spkr":
|
2106
|
+
return {category: s.pkr for category, s in sources_stats().items()}
|
1909
2107
|
|
1910
|
-
if m == "
|
1911
|
-
return
|
2108
|
+
if m == "str":
|
2109
|
+
return {category: s.tr for category, s in sources_stats().items()}
|
1912
2110
|
|
1913
|
-
if m == "
|
1914
|
-
return
|
2111
|
+
if m == "scr":
|
2112
|
+
return {category: s.cr for category, s in sources_stats().items()}
|
1915
2113
|
|
1916
|
-
if m == "
|
1917
|
-
return
|
2114
|
+
if m == "sfl":
|
2115
|
+
return {category: s.fl for category, s in sources_stats().items()}
|
1918
2116
|
|
1919
|
-
if m == "
|
1920
|
-
return
|
2117
|
+
if m == "spkc":
|
2118
|
+
return {category: s.pkc for category, s in sources_stats().items()}
|
1921
2119
|
|
1922
|
-
if m.startswith("
|
1923
|
-
return
|
2120
|
+
if m.startswith("sasr"):
|
2121
|
+
return sources_asr(get_asr_name(m))
|
1924
2122
|
|
1925
|
-
if m.startswith("
|
1926
|
-
return
|
2123
|
+
if m.startswith("mxsasr"):
|
2124
|
+
return source_asr(get_asr_name(m))
|
1927
2125
|
|
1928
2126
|
if m == "ndco":
|
1929
2127
|
return noise_stats().dco
|
@@ -2003,16 +2201,7 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
2003
2201
|
from .db_datatypes import SpectralMaskRecord
|
2004
2202
|
|
2005
2203
|
with db() as c:
|
2006
|
-
spectral_mask = SpectralMaskRecord(
|
2007
|
-
*c.execute(
|
2008
|
-
"""
|
2009
|
-
SELECT *
|
2010
|
-
FROM spectral_mask
|
2011
|
-
WHERE ? = spectral_mask.id
|
2012
|
-
""",
|
2013
|
-
(sm_id,),
|
2014
|
-
).fetchone()
|
2015
|
-
)
|
2204
|
+
spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = id", (sm_id,)).fetchone())
|
2016
2205
|
return SpectralMask(
|
2017
2206
|
f_max_width=spectral_mask.f_max_width,
|
2018
2207
|
f_num=spectral_mask.f_num,
|
@@ -2022,82 +2211,72 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
2022
2211
|
)
|
2023
2212
|
|
2024
2213
|
|
2025
|
-
def
|
2026
|
-
"""Get
|
2214
|
+
def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
|
2215
|
+
"""Get number of source files from category from db
|
2027
2216
|
|
2028
2217
|
:param db: Database context
|
2029
|
-
:param
|
2218
|
+
:param category: Source category
|
2030
2219
|
:param use_cache: If true, use LRU caching
|
2031
|
-
:return:
|
2220
|
+
:return: Number of source files
|
2032
2221
|
"""
|
2033
2222
|
if use_cache:
|
2034
|
-
return
|
2035
|
-
return
|
2223
|
+
return __num_source_files(db, category)
|
2224
|
+
return __num_source_files.__wrapped__(db, category)
|
2036
2225
|
|
2037
2226
|
|
2038
2227
|
@lru_cache
|
2039
|
-
def
|
2040
|
-
"""Get
|
2228
|
+
def __num_source_files(db: partial, category: str) -> int:
|
2229
|
+
"""Get number of source files from category from db
|
2041
2230
|
|
2042
2231
|
:param db: Database context
|
2043
|
-
:param
|
2044
|
-
:
|
2045
|
-
:return: Target file
|
2232
|
+
:param category: Source category
|
2233
|
+
:return: Number of source files
|
2046
2234
|
"""
|
2047
|
-
import json
|
2048
|
-
|
2049
|
-
from .db_datatypes import TargetFileRecord
|
2050
|
-
|
2051
2235
|
with db() as c:
|
2052
|
-
|
2053
|
-
*c.execute(
|
2054
|
-
"""
|
2055
|
-
SELECT *
|
2056
|
-
FROM target_file
|
2057
|
-
WHERE ? = target_file.id
|
2058
|
-
""",
|
2059
|
-
(t_id,),
|
2060
|
-
).fetchone()
|
2061
|
-
)
|
2062
|
-
|
2063
|
-
return TargetFile(
|
2064
|
-
name=target_file.name,
|
2065
|
-
samples=target_file.samples,
|
2066
|
-
class_indices=json.loads(target_file.class_indices),
|
2067
|
-
level_type=target_file.level_type,
|
2068
|
-
truth_configs=_target_truth_configs(db, t_id, use_cache),
|
2069
|
-
speaker_id=target_file.speaker_id,
|
2070
|
-
)
|
2236
|
+
return int(c.execute("SELECT count(id) FROM source_file WHERE ? = category", (category,)).fetchone()[0])
|
2071
2237
|
|
2072
2238
|
|
2073
|
-
def
|
2074
|
-
"""Get
|
2239
|
+
def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
|
2240
|
+
"""Get source file with ID from db
|
2075
2241
|
|
2076
2242
|
:param db: Database context
|
2077
|
-
:param
|
2243
|
+
:param s_id: Source file ID
|
2078
2244
|
:param use_cache: If true, use LRU caching
|
2079
|
-
:return:
|
2245
|
+
:return: Source file
|
2080
2246
|
"""
|
2081
2247
|
if use_cache:
|
2082
|
-
return
|
2083
|
-
return
|
2248
|
+
return __source_file(db, s_id, use_cache)
|
2249
|
+
return __source_file.__wrapped__(db, s_id, use_cache)
|
2084
2250
|
|
2085
2251
|
|
2086
2252
|
@lru_cache
|
2087
|
-
def
|
2253
|
+
def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
|
2254
|
+
"""Get source file with ID from db
|
2255
|
+
|
2256
|
+
:param db: Database context
|
2257
|
+
:param s_id: Source file ID
|
2258
|
+
:param use_cache: If true, use LRU caching
|
2259
|
+
:return: Source file
|
2260
|
+
"""
|
2261
|
+
import json
|
2262
|
+
|
2263
|
+
from .db_datatypes import SourceFileRecord
|
2264
|
+
|
2088
2265
|
with db() as c:
|
2089
|
-
|
2090
|
-
|
2091
|
-
|
2092
|
-
|
2093
|
-
|
2094
|
-
|
2095
|
-
(
|
2096
|
-
|
2097
|
-
|
2266
|
+
source_file = SourceFileRecord(*c.execute("SELECT * FROM source_file WHERE ? = id", (s_id,)).fetchone())
|
2267
|
+
|
2268
|
+
return SourceFile(
|
2269
|
+
category=source_file.category,
|
2270
|
+
name=source_file.name,
|
2271
|
+
samples=source_file.samples,
|
2272
|
+
class_indices=json.loads(source_file.class_indices),
|
2273
|
+
level_type=source_file.level_type,
|
2274
|
+
truth_configs=_source_truth_configs(db, s_id, use_cache),
|
2275
|
+
speaker_id=source_file.speaker_id,
|
2276
|
+
)
|
2098
2277
|
|
2099
2278
|
|
2100
|
-
def
|
2279
|
+
def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
|
2101
2280
|
"""Get impulse response file name with ID from db
|
2102
2281
|
|
2103
2282
|
:param db: Database context
|
@@ -2106,26 +2285,17 @@ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> s
|
|
2106
2285
|
:return: Impulse response file name
|
2107
2286
|
"""
|
2108
2287
|
if use_cache:
|
2109
|
-
return
|
2110
|
-
return
|
2288
|
+
return __ir_file(db, ir_id)
|
2289
|
+
return __ir_file.__wrapped__(db, ir_id)
|
2111
2290
|
|
2112
2291
|
|
2113
2292
|
@lru_cache
|
2114
|
-
def
|
2293
|
+
def __ir_file(db: partial, ir_id: int) -> str:
|
2115
2294
|
with db() as c:
|
2116
|
-
return str(
|
2117
|
-
c.execute(
|
2118
|
-
"""
|
2119
|
-
SELECT impulse_response_file.file
|
2120
|
-
FROM impulse_response_file
|
2121
|
-
WHERE ? = impulse_response_file.id
|
2122
|
-
""",
|
2123
|
-
(ir_id + 1,),
|
2124
|
-
).fetchone()[0]
|
2125
|
-
)
|
2295
|
+
return str(c.execute("SELECT name FROM ir_file WHERE ? = id ", (ir_id + 1,)).fetchone()[0])
|
2126
2296
|
|
2127
2297
|
|
2128
|
-
def
|
2298
|
+
def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
|
2129
2299
|
"""Get impulse response delay with ID from db
|
2130
2300
|
|
2131
2301
|
:param db: Database context
|
@@ -2134,23 +2304,14 @@ def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) ->
|
|
2134
2304
|
:return: Impulse response delay
|
2135
2305
|
"""
|
2136
2306
|
if use_cache:
|
2137
|
-
return
|
2138
|
-
return
|
2307
|
+
return __ir_delay(db, ir_id)
|
2308
|
+
return __ir_delay.__wrapped__(db, ir_id)
|
2139
2309
|
|
2140
2310
|
|
2141
2311
|
@lru_cache
|
2142
|
-
def
|
2312
|
+
def __ir_delay(db: partial, ir_id: int) -> int:
|
2143
2313
|
with db() as c:
|
2144
|
-
return int(
|
2145
|
-
c.execute(
|
2146
|
-
"""
|
2147
|
-
SELECT impulse_response_file.delay
|
2148
|
-
FROM impulse_response_file
|
2149
|
-
WHERE ? = impulse_response_file.id
|
2150
|
-
""",
|
2151
|
-
(ir_id + 1,),
|
2152
|
-
).fetchone()[0]
|
2153
|
-
)
|
2314
|
+
return int(c.execute("SELECT delay FROM ir_file WHERE ? = id", (ir_id + 1,)).fetchone()[0])
|
2154
2315
|
|
2155
2316
|
|
2156
2317
|
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
@@ -2169,35 +2330,27 @@ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
|
2169
2330
|
@lru_cache
|
2170
2331
|
def __mixture(db: partial, m_id: int) -> Mixture:
|
2171
2332
|
from .db_datatypes import MixtureRecord
|
2172
|
-
from .db_datatypes import
|
2333
|
+
from .db_datatypes import SourceRecord
|
2173
2334
|
from .helpers import to_mixture
|
2174
|
-
from .helpers import
|
2335
|
+
from .helpers import to_source
|
2175
2336
|
|
2176
2337
|
with db() as c:
|
2177
|
-
mixture = MixtureRecord(
|
2178
|
-
*c.execute(
|
2179
|
-
"""
|
2180
|
-
SELECT *
|
2181
|
-
FROM mixture
|
2182
|
-
WHERE ? = mixture.id
|
2183
|
-
""",
|
2184
|
-
(m_id + 1,),
|
2185
|
-
).fetchone()
|
2186
|
-
)
|
2338
|
+
mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = id", (m_id + 1,)).fetchone())
|
2187
2339
|
|
2188
|
-
|
2189
|
-
|
2190
|
-
|
2191
|
-
|
2192
|
-
|
2193
|
-
|
2194
|
-
WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
|
2340
|
+
sources: Sources = {}
|
2341
|
+
for source in c.execute(
|
2342
|
+
"""
|
2343
|
+
SELECT source.*
|
2344
|
+
FROM source, mixture_source
|
2345
|
+
WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
|
2195
2346
|
""",
|
2196
|
-
|
2197
|
-
|
2198
|
-
|
2347
|
+
(mixture.id,),
|
2348
|
+
).fetchall():
|
2349
|
+
s = SourceRecord(*source)
|
2350
|
+
category = c.execute("SELECT category FROM source_file WHERE ? = id", (s.file_id,)).fetchone()[0]
|
2351
|
+
sources[category] = to_source(s)
|
2199
2352
|
|
2200
|
-
return to_mixture(mixture,
|
2353
|
+
return to_mixture(mixture, sources)
|
2201
2354
|
|
2202
2355
|
|
2203
2356
|
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
|
@@ -2220,27 +2373,55 @@ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
|
2220
2373
|
return data[0]
|
2221
2374
|
|
2222
2375
|
|
2223
|
-
def
|
2376
|
+
def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
|
2377
|
+
if use_cache:
|
2378
|
+
return __category_truth_configs(db, category)
|
2379
|
+
return __category_truth_configs.__wrapped__(db, category)
|
2380
|
+
|
2381
|
+
|
2382
|
+
@lru_cache
|
2383
|
+
def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
|
2384
|
+
import json
|
2385
|
+
|
2386
|
+
truth_configs: dict[str, str] = {}
|
2387
|
+
with db() as c:
|
2388
|
+
s_ids = c.execute("SELECT id FROM source_file WHERE ? = category", (category,)).fetchall()
|
2389
|
+
|
2390
|
+
for s_id in s_ids:
|
2391
|
+
for truth_config_record in c.execute(
|
2392
|
+
"""
|
2393
|
+
SELECT truth_config.config
|
2394
|
+
FROM truth_config, source_file_truth_config
|
2395
|
+
WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
|
2396
|
+
""",
|
2397
|
+
(s_id[0],),
|
2398
|
+
).fetchall():
|
2399
|
+
truth_config = json.loads(truth_config_record[0])
|
2400
|
+
truth_configs[truth_config["name"]] = truth_config["function"]
|
2401
|
+
return truth_configs
|
2402
|
+
|
2403
|
+
|
2404
|
+
def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
|
2224
2405
|
if use_cache:
|
2225
|
-
return
|
2226
|
-
return
|
2406
|
+
return __source_truth_configs(db, s_id)
|
2407
|
+
return __source_truth_configs.__wrapped__(db, s_id)
|
2227
2408
|
|
2228
2409
|
|
2229
2410
|
@lru_cache
|
2230
|
-
def
|
2411
|
+
def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
|
2231
2412
|
import json
|
2232
2413
|
|
2233
|
-
from
|
2414
|
+
from ..datatypes import TruthConfig
|
2234
2415
|
|
2235
2416
|
truth_configs: TruthConfigs = {}
|
2236
2417
|
with db() as c:
|
2237
2418
|
for truth_config_record in c.execute(
|
2238
2419
|
"""
|
2239
2420
|
SELECT truth_config.config
|
2240
|
-
FROM truth_config,
|
2241
|
-
WHERE ? =
|
2421
|
+
FROM truth_config, source_file_truth_config
|
2422
|
+
WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
|
2242
2423
|
""",
|
2243
|
-
(
|
2424
|
+
(s_id,),
|
2244
2425
|
).fetchall():
|
2245
2426
|
truth_config = json.loads(truth_config_record[0])
|
2246
2427
|
truth_configs[truth_config["name"]] = TruthConfig(
|