sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/mixture/mixdb.py
ADDED
@@ -0,0 +1,1857 @@
|
|
1
|
+
# ruff: noqa: S608
|
2
|
+
from functools import cached_property
|
3
|
+
from functools import lru_cache
|
4
|
+
from functools import partial
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
from ..datatypes import ASRConfigs
|
8
|
+
from ..datatypes import AudioF
|
9
|
+
from ..datatypes import AudioT
|
10
|
+
from ..datatypes import ClassCount
|
11
|
+
from ..datatypes import Feature
|
12
|
+
from ..datatypes import FeatureGeneratorConfig
|
13
|
+
from ..datatypes import FeatureGeneratorInfo
|
14
|
+
from ..datatypes import GeneralizedIDs
|
15
|
+
from ..datatypes import ImpulseResponseFile
|
16
|
+
from ..datatypes import MetricDoc
|
17
|
+
from ..datatypes import MetricDocs
|
18
|
+
from ..datatypes import Mixture
|
19
|
+
from ..datatypes import Segsnr
|
20
|
+
from ..datatypes import SourceFile
|
21
|
+
from ..datatypes import Sources
|
22
|
+
from ..datatypes import SourcesAudioF
|
23
|
+
from ..datatypes import SourcesAudioT
|
24
|
+
from ..datatypes import SpectralMask
|
25
|
+
from ..datatypes import SpeechMetadata
|
26
|
+
from ..datatypes import TransformConfig
|
27
|
+
from ..datatypes import TruthConfigs
|
28
|
+
from ..datatypes import TruthDict
|
29
|
+
from ..datatypes import TruthsConfigs
|
30
|
+
from ..datatypes import TruthsDict
|
31
|
+
from ..datatypes import UniversalSNR
|
32
|
+
from .db import SQLiteDatabase
|
33
|
+
from .db import db_file
|
34
|
+
|
35
|
+
|
36
|
+
class MixtureDatabase:
|
37
|
+
def __init__(
    self,
    location: str,
    test: bool = False,
    verbose: bool = False,
    use_cache: bool = True,
) -> None:
    """Set up access to the mixture database at ``location``.

    :param location: Location of the mixture database
    :param test: Forwarded to db_file() and SQLiteDatabase
    :param verbose: Forwarded to SQLiteDatabase
    :param use_cache: Forwarded to the module-level lookup helpers
    """
    self.location, self.test = location, test
    self.db_path = db_file(location=location, test=test)
    self.verbose, self.use_cache = verbose, use_cache

    # Connection factory with location/test/verbose already bound
    self.db = partial(SQLiteDatabase, location=location, test=test, verbose=verbose)

    # Update ASR configs
    self.update_asr_configs()
|
54
|
+
|
55
|
+
def update_asr_configs(self) -> None:
    """Update the asr_configs column in the top table with the current asr_configs in the config.yml file.

    The ``asr_configs`` entry of the loaded config is serialized to JSON and
    compared with the stored value; the table is only written when they differ.
    """
    import json

    from ..config.config import load_config

    # Check config.yml to see if asr_configs has changed and update the database if needed
    config = load_config(self.location)
    new_asr_configs = json.dumps(config["asr_configs"])
    with SQLiteDatabase(
        location=self.location,
        readonly=False,
        test=self.test,
        verbose=self.verbose,
    ) as c:
        old_asr_configs = c.execute("SELECT asr_configs FROM top").fetchone()

        if old_asr_configs is not None and new_asr_configs != old_asr_configs[0]:
            # The statement has two placeholders, so two values must be bound;
            # binding only one raises sqlite3.ProgrammingError.
            # NOTE(review): assumes the single row of `top` has id 1 - confirm against the schema.
            c.execute("UPDATE top SET asr_configs = ? WHERE ? = id", (new_asr_configs, 1))
|
74
|
+
|
75
|
+
@cached_property
def json(self) -> str:
    """Serialize the database contents as a JSON string.

    :return: JSON text produced by MixtureDatabaseConfig.to_json(indent=2)
    """
    from ..datatypes import MixtureDatabaseConfig

    # Keyword arguments are evaluated in the order listed
    fields = {
        "asr_configs": self.asr_configs,
        "class_balancing": self.class_balancing,
        "class_labels": self.class_labels,
        "class_weights_threshold": self.class_weights_thresholds,
        "feature": self.feature,
        "ir_files": self.ir_files,
        "mixtures": self.mixtures,
        "num_classes": self.num_classes,
        "spectral_masks": self.spectral_masks,
        "source_files": self.source_files,
    }
    return MixtureDatabaseConfig(**fields).to_json(indent=2)
|
92
|
+
|
93
|
+
def save(self) -> None:
|
94
|
+
"""Save the MixtureDatabase as a JSON file"""
|
95
|
+
from os.path import join
|
96
|
+
|
97
|
+
json_name = join(self.location, "mixdb.json")
|
98
|
+
with open(file=json_name, mode="w") as file:
|
99
|
+
file.write(self.json)
|
100
|
+
|
101
|
+
@cached_property
def fg_config(self) -> FeatureGeneratorConfig:
    """Feature generator configuration built from the feature mode and truth parameters.

    :return: Feature generator config
    """
    mode = self.feature
    parameters = self.truth_parameters
    return FeatureGeneratorConfig(feature_mode=mode, truth_parameters=parameters)
|
107
|
+
|
108
|
+
@cached_property
def fg_info(self) -> FeatureGeneratorInfo:
    """Feature generator info derived from ``self.fg_config``.

    :return: Result of get_feature_generator_info() for this database's config
    """
    from .helpers import get_feature_generator_info

    info = get_feature_generator_info(self.fg_config)
    return info
|
113
|
+
|
114
|
+
@cached_property
|
115
|
+
def truth_parameters(self) -> dict[str, dict[str, int | None]]:
|
116
|
+
with self.db() as c:
|
117
|
+
rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
|
118
|
+
truth_parameters: dict[str, dict[str, int | None]] = {}
|
119
|
+
for row in rows:
|
120
|
+
category, name, parameters = row
|
121
|
+
if category not in truth_parameters:
|
122
|
+
truth_parameters[category] = {}
|
123
|
+
truth_parameters[category][name] = parameters
|
124
|
+
return truth_parameters
|
125
|
+
|
126
|
+
@cached_property
|
127
|
+
def num_classes(self) -> int:
|
128
|
+
with self.db() as c:
|
129
|
+
return int(c.execute("SELECT num_classes FROM top").fetchone()[0])
|
130
|
+
|
131
|
+
@cached_property
|
132
|
+
def asr_configs(self) -> ASRConfigs:
|
133
|
+
import json
|
134
|
+
|
135
|
+
with self.db() as c:
|
136
|
+
return json.loads(c.execute("SELECT asr_configs FROM top").fetchone()[0])
|
137
|
+
|
138
|
+
@cached_property
def supported_metrics(self) -> MetricDocs:
    """Build the documentation entries for every supported metric.

    The fixed entries come first (mixture, audio statistics per group, truth),
    followed by five ASR-related entries for each name in ``self.asr_configs``.

    :return: Metric documentation entries
    """
    # (name, description) pairs; all belong to the "Mixture Metrics" category
    mixture_rows = [
        ("mxsnr", "SNR specification in dB"),
        ("mxssnr_avg", "Segmental SNR average over all frames"),
        ("mxssnr_std", "Segmental SNR standard deviation over all frames"),
        ("mxssnrdb_avg", "Segmental SNR average of the dB frame values over all frames"),
        ("mxssnrdb_std", "Segmental SNR standard deviation of the dB frame values over all frames"),
        ("mxssnrf_avg", "Per-bin segmental SNR average over all frames (using feature transform)"),
        ("mxssnrf_std", "Per-bin segmental SNR standard deviation over all frames (using feature transform)"),
        ("mxssnrdbf_avg", "Per-bin segmental average of the dB frame values over all frames (using feature transform)"),
        (
            "mxssnrdbf_std",
            "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
        ),
        ("mxpesq", "PESQ of mixture versus true sources"),
        ("mxwsdr", "Weighted signal distortion ratio of mixture versus true sources"),
        ("mxpd", "Phase distance between mixture and true sources"),
        ("mxstoi", "Short term objective intelligibility of mixture versus true sources"),
        ("mxcsig", "Predicted rating of speech distortion of mixture versus true sources"),
        ("mxcbak", "Predicted rating of background distortion of mixture versus true sources"),
        ("mxcovl", "Predicted rating of overall quality of mixture versus true sources"),
        ("ssnr", "Segmental SNR"),
    ]
    # Audio statistics: the same suffix/description endings are used by every group
    stats_rows = [
        ("dco", "DC offset"),
        ("min", "min level"),
        ("max", "max levl"),
        ("pkdb", "Pk lev dB"),
        ("lrms", "RMS lev dB"),
        ("pkr", "RMS Pk dB"),
        ("tr", "RMS Tr dB"),
        ("cr", "Crest factor"),
        ("fl", "Flat factor"),
        ("pkc", "Pk count"),
    ]
    # (category, metric name prefix, description prefix)
    stats_groups = [
        ("Mixture Metrics", "mx", "Mixture"),
        ("Sources Metrics", "s", "Sources"),
        ("Source Metrics", "mxs", "Source"),
        ("Noise Metrics", "n", "Noise"),
    ]
    truth_rows = [
        ("sedavg", "(not implemented) Average SED activity over all frames [truth_parameters, 1]"),
        ("sedcnt", "(not implemented) Count in number of frames that SED is active [truth_parameters, 1]"),
        ("sedtop3", "(not implemented) 3 most active by largest sedavg [3, 1]"),
        ("sedtopn", "(not implemented) N most active by largest sedavg [N, 1]"),
    ]

    docs = [MetricDoc("Mixture Metrics", metric, text) for metric, text in mixture_rows]
    for category, prefix, label in stats_groups:
        for suffix, ending in stats_rows:
            docs.append(MetricDoc(category, f"{prefix}{suffix}", f"{label} {ending}"))
    docs.extend(MetricDoc("Truth Metrics", metric, text) for metric, text in truth_rows)

    metrics = MetricDocs(docs)
    for name in self.asr_configs:
        asr_rows = (
            (
                "Source Metrics",
                f"mxsasr.{name}",
                f"Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Sources Metrics",
                f"sasr.{name}",
                f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Mixture Metrics",
                f"mxasr.{name}",
                f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Sources Metrics",
                f"basewer.{name}",
                f"Word error rate of sasr.{name} vs. speech text metadata for the source",
            ),
            (
                "Mixture Metrics",
                f"mxwer.{name}",
                f"Word error rate of mxasr.{name} vs. sasr.{name}",
            ),
        )
        for category, metric, text in asr_rows:
            metrics.append(MetricDoc(category, metric, text))

    return metrics
|
315
|
+
|
316
|
+
@cached_property
|
317
|
+
def class_balancing(self) -> bool:
|
318
|
+
with self.db() as c:
|
319
|
+
return bool(c.execute("SELECT class_balancing FROM top").fetchone()[0])
|
320
|
+
|
321
|
+
@cached_property
|
322
|
+
def feature(self) -> str:
|
323
|
+
with self.db() as c:
|
324
|
+
return str(c.execute("SELECT feature FROM top").fetchone()[0])
|
325
|
+
|
326
|
+
@cached_property
|
327
|
+
def fg_decimation(self) -> int:
|
328
|
+
return self.fg_info.decimation
|
329
|
+
|
330
|
+
@cached_property
|
331
|
+
def fg_stride(self) -> int:
|
332
|
+
return self.fg_info.stride
|
333
|
+
|
334
|
+
@cached_property
|
335
|
+
def fg_step(self) -> int:
|
336
|
+
return self.fg_info.step
|
337
|
+
|
338
|
+
@cached_property
|
339
|
+
def feature_parameters(self) -> int:
|
340
|
+
return self.fg_info.feature_parameters
|
341
|
+
|
342
|
+
@cached_property
|
343
|
+
def ft_config(self) -> TransformConfig:
|
344
|
+
return self.fg_info.ft_config
|
345
|
+
|
346
|
+
@cached_property
|
347
|
+
def eft_config(self) -> TransformConfig:
|
348
|
+
return self.fg_info.eft_config
|
349
|
+
|
350
|
+
@cached_property
|
351
|
+
def it_config(self) -> TransformConfig:
|
352
|
+
return self.fg_info.it_config
|
353
|
+
|
354
|
+
@cached_property
def transform_frame_ms(self) -> float:
    """Duration in milliseconds of ``ft_config.overlap`` samples at SAMPLE_RATE.

    :return: overlap divided by the number of samples per millisecond
    """
    from ..constants import SAMPLE_RATE

    samples_per_ms = float(SAMPLE_RATE / 1000)
    overlap = float(self.ft_config.overlap)
    return overlap / samples_per_ms
|
359
|
+
|
360
|
+
@cached_property
|
361
|
+
def feature_ms(self) -> float:
|
362
|
+
return self.transform_frame_ms * self.fg_decimation * self.fg_stride
|
363
|
+
|
364
|
+
@cached_property
|
365
|
+
def feature_samples(self) -> int:
|
366
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_stride
|
367
|
+
|
368
|
+
@cached_property
|
369
|
+
def feature_step_ms(self) -> float:
|
370
|
+
return self.transform_frame_ms * self.fg_decimation * self.fg_step
|
371
|
+
|
372
|
+
@cached_property
|
373
|
+
def feature_step_samples(self) -> int:
|
374
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_step
|
375
|
+
|
376
|
+
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
|
377
|
+
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
|
378
|
+
|
379
|
+
def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
380
|
+
return self.total_samples(m_ids) // self.ft_config.overlap
|
381
|
+
|
382
|
+
def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
383
|
+
return self.total_samples(m_ids) // self.feature_step_samples
|
384
|
+
|
385
|
+
def mixture_transform_frames(self, m_id: int) -> int:
    """Number of transform frames in one mixture.

    :param m_id: Mixture ID
    :return: frames_from_samples() of the mixture's samples and ft_config.overlap
    """
    from .helpers import frames_from_samples

    samples = self.mixture(m_id).samples
    return frames_from_samples(samples, self.ft_config.overlap)
|
389
|
+
|
390
|
+
def mixture_feature_frames(self, m_id: int) -> int:
    """Number of feature frames in one mixture.

    :param m_id: Mixture ID
    :return: frames_from_samples() of the mixture's samples and feature_step_samples
    """
    from .helpers import frames_from_samples

    samples = self.mixture(m_id).samples
    return frames_from_samples(samples, self.feature_step_samples)
|
394
|
+
|
395
|
+
def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
    """Expand generalized mixture IDs into a list of integers

    :param m_ids: Generalized mixture IDs
    :return: List of mixture ID integers
    """
    from .helpers import generic_ids_to_list

    count = self.num_mixtures
    return generic_ids_to_list(count, m_ids)
|
404
|
+
|
405
|
+
@cached_property
|
406
|
+
def class_labels(self) -> list[str]:
|
407
|
+
"""Get class labels from db
|
408
|
+
|
409
|
+
:return: Class labels
|
410
|
+
"""
|
411
|
+
with self.db() as c:
|
412
|
+
return [str(item[0]) for item in c.execute("SELECT label FROM class_label ORDER BY id").fetchall()]
|
413
|
+
|
414
|
+
@cached_property
|
415
|
+
def class_weights_thresholds(self) -> list[float]:
|
416
|
+
"""Get class weights thresholds from db
|
417
|
+
|
418
|
+
:return: Class weights thresholds
|
419
|
+
"""
|
420
|
+
with self.db() as c:
|
421
|
+
return [float(item[0]) for item in c.execute("SELECT threshold FROM class_weights_threshold").fetchall()]
|
422
|
+
|
423
|
+
def category_truth_configs(self, category: str) -> dict[str, str]:
    """Get the truth configs for a source category

    Delegates to the module-level ``_category_truth_configs`` helper.

    :param category: Source category
    :return: Truth configs for the category
    """
    configs = _category_truth_configs(self.db, category, self.use_cache)
    return configs
|
425
|
+
|
426
|
+
def source_truth_configs(self, s_id: int) -> TruthConfigs:
    """Get the truth configs for a source file

    Delegates to the module-level ``_source_truth_configs`` helper.

    :param s_id: Source file ID
    :return: Truth configs for the source file
    """
    configs = _source_truth_configs(self.db, s_id, self.use_cache)
    return configs
|
428
|
+
|
429
|
+
def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
|
430
|
+
mixture = self.mixture(m_id)
|
431
|
+
return {
|
432
|
+
category: self.source_truth_configs(mixture.all_sources[category].file_id)
|
433
|
+
for category in mixture.all_sources
|
434
|
+
}
|
435
|
+
|
436
|
+
@cached_property
|
437
|
+
def random_snrs(self) -> list[float]:
|
438
|
+
"""Get random snrs from db
|
439
|
+
|
440
|
+
:return: Random SNRs
|
441
|
+
"""
|
442
|
+
with self.db() as c:
|
443
|
+
return list(
|
444
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()}
|
445
|
+
)
|
446
|
+
|
447
|
+
@cached_property
|
448
|
+
def snrs(self) -> list[float]:
|
449
|
+
"""Get snrs from db
|
450
|
+
|
451
|
+
:return: SNRs
|
452
|
+
"""
|
453
|
+
with self.db() as c:
|
454
|
+
return list(
|
455
|
+
{float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()}
|
456
|
+
)
|
457
|
+
|
458
|
+
@cached_property
def all_snrs(self) -> list[UniversalSNR]:
    """Get the fixed and random SNRs together

    :return: Sorted, de-duplicated UniversalSNR values built from snrs and random_snrs
    """
    fixed = [UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
    randomized = [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
    combined = set(fixed + randomized)
    return sorted(combined)
|
466
|
+
|
467
|
+
@cached_property
def spectral_masks(self) -> list[SpectralMask]:
    """Get spectral masks from db

    Each row of the spectral_mask table is unpacked into a SpectralMaskRecord,
    then converted to a SpectralMask.

    :return: Spectral masks
    """
    from .db_datatypes import SpectralMaskRecord

    with self.db() as c:
        rows = c.execute("SELECT * FROM spectral_mask").fetchall()
        records = [SpectralMaskRecord(*row) for row in rows]

        masks = []
        for record in records:
            masks.append(
                SpectralMask(
                    f_max_width=record.f_max_width,
                    f_num=record.f_num,
                    t_max_width=record.t_max_width,
                    t_num=record.t_num,
                    t_max_percent=record.t_max_percent,
                )
            )
        return masks
|
489
|
+
|
490
|
+
def spectral_mask(self, sm_id: int) -> SpectralMask:
    """Get spectral mask with ID from db

    Delegates to the module-level ``_spectral_mask`` helper.

    :param sm_id: Spectral mask ID
    :return: Spectral mask
    """
    mask = _spectral_mask(self.db, sm_id, self.use_cache)
    return mask
|
497
|
+
|
498
|
+
@cached_property
def source_files(self) -> dict[str, list[SourceFile]]:
    """Get source files from db

    :return: Mapping of category to the source files of that category
    """
    import json

    from ..datatypes import TruthConfig
    from ..datatypes import TruthConfigs
    from .db_datatypes import SourceFileRecord

    # Truth configs attached to one source file through the link table
    truth_query = """
                        SELECT truth_config.config
                        FROM truth_config, source_file_truth_config
                        WHERE ? = source_file_truth_config.source_file_id
                        AND truth_config.id = source_file_truth_config.truth_config_id
                        """

    with self.db() as c:
        by_category: dict[str, list[SourceFile]] = {}
        for (category,) in ((row[0],) for row in c.execute("SELECT DISTINCT category FROM source_file").fetchall()):
            by_category[category] = []
            rows = c.execute("SELECT * FROM source_file WHERE ? = category", (category,)).fetchall()
            records = [SourceFileRecord(*row) for row in rows]
            for record in records:
                truth_configs: TruthConfigs = {}
                for truth_row in c.execute(truth_query, (record.id,)).fetchall():
                    decoded = json.loads(truth_row[0])
                    truth_configs[decoded["name"]] = TruthConfig(
                        function=decoded["function"],
                        stride_reduction=decoded["stride_reduction"],
                        config=decoded["config"],
                    )

                source_file = SourceFile(
                    id=record.id,
                    category=record.category,
                    name=record.name,
                    samples=record.samples,
                    class_indices=json.loads(record.class_indices),
                    level_type=record.level_type,
                    truth_configs=truth_configs,
                    speaker_id=record.speaker_id,
                )
                by_category[record.category].append(source_file)
        return by_category
|
549
|
+
|
550
|
+
@cached_property
|
551
|
+
def source_file_ids(self) -> dict[str, list[int]]:
|
552
|
+
"""Get source file IDs from db
|
553
|
+
|
554
|
+
:return: Dictionary of a list of source file IDs
|
555
|
+
"""
|
556
|
+
with self.db() as c:
|
557
|
+
source_file_ids: dict[str, list[int]] = {}
|
558
|
+
categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
|
559
|
+
for category in categories:
|
560
|
+
source_file_ids[category[0]] = [
|
561
|
+
int(item[0])
|
562
|
+
for item in c.execute("SELECT id FROM source_file WHERE ? = category", (category[0],)).fetchall()
|
563
|
+
]
|
564
|
+
return source_file_ids
|
565
|
+
|
566
|
+
def source_file(self, s_id: int) -> SourceFile:
    """Get the source file with ID from db

    Delegates to the module-level ``_source_file`` helper.

    :param s_id: Source file ID
    :return: Source file
    """
    found = _source_file(self.db, s_id, self.use_cache)
    return found
|
573
|
+
|
574
|
+
def num_source_files(self, category: str) -> int:
    """Get the number of source files from the category from db

    Delegates to the module-level ``_num_source_files`` helper.

    :param category: Source category
    :return: Number of source files
    """
    count = _num_source_files(self.db, category, self.use_cache)
    return count
|
581
|
+
|
582
|
+
@cached_property
def ir_files(self) -> list[ImpulseResponseFile]:
    """Get impulse response files from db

    :return: Impulse response files, each with its delay, name and tags
    """
    from .db_datatypes import ImpulseResponseFileRecord

    # Tags attached to one impulse response file through the link table
    tag_query = """
                        SELECT ir_tag.tag
                        FROM ir_tag, ir_file_ir_tag
                        WHERE ? = ir_file_ir_tag.file_id
                        AND ir_tag.id = ir_file_ir_tag.tag_id
                        """

    with self.db() as c:
        result: list[ImpulseResponseFile] = []
        for row in c.execute("SELECT * FROM ir_file").fetchall():
            record = ImpulseResponseFileRecord(*row)
            tag_rows = c.execute(tag_query, (record.id,)).fetchall()
            tags = [tag_row[0] for tag_row in tag_rows]
            result.append(ImpulseResponseFile(delay=record.delay, name=record.name, tags=tags))

        return result
|
618
|
+
|
619
|
+
@cached_property
|
620
|
+
def ir_file_ids(self) -> list[int]:
|
621
|
+
"""Get impulse response file IDs from db
|
622
|
+
|
623
|
+
:return: List of impulse response file IDs
|
624
|
+
"""
|
625
|
+
with self.db() as c:
|
626
|
+
return [int(item[0]) for item in c.execute("SELECT id FROM ir_file").fetchall()]
|
627
|
+
|
628
|
+
def ir_file_ids_for_tag(self, tag: str) -> list[int]:
|
629
|
+
"""Get impulse response file IDs for the given tag from db
|
630
|
+
|
631
|
+
:return: List of impulse response file IDs for the given tag
|
632
|
+
"""
|
633
|
+
with self.db() as c:
|
634
|
+
tag_id = c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,)).fetchone()
|
635
|
+
if not tag_id:
|
636
|
+
return []
|
637
|
+
|
638
|
+
return [
|
639
|
+
int(item[0] - 1)
|
640
|
+
for item in c.execute("SELECT file_id FROM ir_file_ir_tag WHERE ? = tag_id", (tag_id[0],)).fetchall()
|
641
|
+
]
|
642
|
+
|
643
|
+
def ir_file(self, ir_id: int) -> str:
    """Get impulse response file name with ID from db

    Delegates to the module-level ``_ir_file`` helper.

    :param ir_id: Impulse response file ID
    :return: Impulse response file name
    """
    name = _ir_file(self.db, ir_id, self.use_cache)
    return name
|
650
|
+
|
651
|
+
def ir_delay(self, ir_id: int) -> int:
    """Get impulse response delay with ID from db

    Delegates to the module-level ``_ir_delay`` helper.

    :param ir_id: Impulse response file ID
    :return: Impulse response delay
    """
    delay = _ir_delay(self.db, ir_id, self.use_cache)
    return delay
|
658
|
+
|
659
|
+
@cached_property
|
660
|
+
def num_ir_files(self) -> int:
|
661
|
+
"""Get number of impulse response files from db
|
662
|
+
|
663
|
+
:return: Number of impulse response files
|
664
|
+
"""
|
665
|
+
with self.db() as c:
|
666
|
+
return int(c.execute("SELECT count(id) FROM ir_file").fetchone()[0])
|
667
|
+
|
668
|
+
@cached_property
|
669
|
+
def ir_tags(self) -> list[str]:
|
670
|
+
"""Get tags of impulse response files from db
|
671
|
+
|
672
|
+
:return: Tags of impulse response files
|
673
|
+
"""
|
674
|
+
with self.db() as c:
|
675
|
+
return [tag[0] for tag in c.execute("SELECT tag FROM ir_tag").fetchall()]
|
676
|
+
|
677
|
+
@property
def mixtures(self) -> list[Mixture]:
    """Get mixtures from db

    This is a plain property, so the database is queried on every access.

    :return: Mixtures
    """
    from .db_datatypes import MixtureRecord
    from .db_datatypes import SourceRecord
    from .helpers import to_mixture
    from .helpers import to_source

    # Sources attached to one mixture through the link table
    source_query = """
                        SELECT source.*
                        FROM source, mixture_source
                        WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
                        """

    with self.db() as c:
        result: list[Mixture] = []
        records = [MixtureRecord(*row) for row in c.execute("SELECT * FROM mixture").fetchall()]
        for record in records:
            source_rows = c.execute(source_query, (record.id,)).fetchall()
            converted = [to_source(SourceRecord(*source_row)) for source_row in source_rows]

            # Key each source by the category of its source file
            sources: Sources = {}
            for item in converted:
                category = self.source_file(item.file_id).category
                sources[category] = item

            result.append(to_mixture(record, sources))

        return result
|
710
|
+
|
711
|
+
@cached_property
def mixture_ids(self) -> list[int]:
    """Collect the mixture IDs recorded in the db

    :return: List of zero-based mixture IDs
    """
    with self.db() as cursor:
        ids = []
        for row in cursor.execute("SELECT id FROM mixture").fetchall():
            # Stored IDs are shifted down by one to produce the zero-based IDs returned here
            ids.append(int(row[0]) - 1)
        return ids
|
719
|
+
|
720
|
+
def mixture(self, m_id: int) -> Mixture:
    """Return the mixture record stored in the db for an ID

    :param m_id: Zero-based mixture ID
    :return: Mixture record
    """
    # The lookup lives in the module-level helper; forward this instance's db handle and cache setting
    record = _mixture(self.db, m_id, self.use_cache)
    return record
|
727
|
+
|
728
|
+
@cached_property
def mixid_width(self) -> int:
    """Read the mixid_width value from the top table of the db

    :return: mixid_width as an integer
    """
    with self.db() as cursor:
        width_row = cursor.execute("SELECT mixid_width FROM top").fetchone()
        return int(width_row[0])
|
732
|
+
|
733
|
+
def mixture_location(self, m_id: int) -> str:
    """Get the file location for the given mixture ID

    :param m_id: Zero-based mixture ID
    :return: File location (the database location joined with the mixture name)
    """
    from os.path import join

    base = self.location
    mixture_name = self.mixture(m_id).name
    return join(base, mixture_name)
|
742
|
+
|
743
|
+
@cached_property
def num_mixtures(self) -> int:
    """Count the mixtures recorded in the db

    :return: Number of mixtures
    """
    with self.db() as cursor:
        count_row = cursor.execute("SELECT count(id) FROM mixture").fetchone()
        return int(count_row[0])
|
751
|
+
|
752
|
+
def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
    """Read cached datasets belonging to a mixture

    :param m_id: Zero-based mixture ID
    :param items: String(s) of dataset(s) to retrieve
    :return: Dictionary of name: data
    """
    from .data_io import read_cached_data

    location = self.location
    # Cached mixture data is indexed by the mixture's name
    mixture_name = self.mixture(m_id).name
    return read_cached_data(location, "mixture", mixture_name, items)
|
762
|
+
|
763
|
+
def read_source_audio(self, s_id: int) -> AudioT:
    """Read the audio of a source file

    :param s_id: Source ID
    :return: Source audio
    """
    from .audio import read_audio

    # Resolve the ID to its file name, then let read_audio load it with this instance's cache setting
    file_name = self.source_file(s_id).name
    return read_audio(file_name, self.use_cache)
|
772
|
+
|
773
|
+
def mixture_class_indices(self, m_id: int) -> list[int]:
    """Get the class indices of all source files used by a mixture

    :param m_id: Zero-based mixture ID
    :return: Sorted list of unique class indices
    """
    unique_indices: set[int] = set()
    for source_id in self.mixture(m_id).source_ids.values():
        unique_indices.update(self.source_file(source_id).class_indices)
    return sorted(unique_indices)
|
778
|
+
|
779
|
+
def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
    """Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Dictionary of pre-truth source audio data (one per source in the mixture)
    :raises ValueError: If no mixture is found for m_id
    """
    from .data_io import write_cached_data
    from .effects import apply_effects
    from .effects import conform_audio_to_length

    if not force:
        cached = self.read_mixture_data(m_id, "sources")["sources"]
        if cached is not None:
            return cached

    mixture = self.mixture(m_id)
    if mixture is None:
        raise ValueError(f"Could not find mixture for m_id: {m_id}")

    sources = {}
    for category, source in mixture.all_sources.items():
        # Read the file, apply only the pre effects, then conform the result to mixture.samples
        raw = self.read_source_audio(source.file_id)
        processed = apply_effects(self, raw, source.effects, pre=True, post=False)
        sources[category] = conform_audio_to_length(processed, mixture.samples, source.loop, source.start)

    if not cache:
        return sources

    write_cached_data(
        location=self.location,
        name="mixture",
        index=mixture.name,
        items={"sources": sources},
    )
    return sources
|
817
|
+
|
818
|
+
def mixture_sources_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> SourcesAudioF:
    """Get the pre-truth source transform data (one per source in the mixture) for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Dictionary of pre-truth source transform data (one per source in the mixture)
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    if sources is None:
        sources = self.mixture_sources(m_id, force)

    # Transform each category's audio with this database's forward transform configuration
    sources_f = {}
    for category in sources:
        sources_f[category] = forward_transform(sources[category], self.ft_config)

    if not cache:
        return sources_f

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"sources_f": sources_f},
    )
    return sources_f
|
850
|
+
|
851
|
+
def mixture_source(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Get the post-truth, summed, and gained source audio data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth, gained, and summed source audio data
    """
    import numpy as np

    from .data_io import write_cached_data
    from .effects import apply_effects

    if not force:
        cached = self.read_mixture_data(m_id, "source")["source"]
        if cached is not None:
            return cached

    if sources is None:
        sources = self.mixture_sources(m_id, force)

    mixture = self.mixture(m_id)

    # Apply the post effects and SNR gain to every category except "noise", then sum the results
    gained = []
    for category in sources:
        if category == "noise":
            continue
        post_audio = apply_effects(
            self,
            audio=sources[category],
            effects=mixture.all_sources[category].effects,
            pre=False,
            post=True,
        )
        gained.append(post_audio * mixture.all_sources[category].snr_gain)

    source = np.sum(gained, axis=0)

    if not cache:
        return source

    write_cached_data(
        location=self.location,
        name="mixture",
        index=mixture.name,
        items={"source": source},
    )
    return source
|
906
|
+
|
907
|
+
def mixture_source_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Get the post-truth, summed, and gained source transform data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio for the given m_id
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth, gained, and summed source transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    # Only compute the time-domain source when the caller did not supply one
    time_domain = source if source is not None else self.mixture_source(m_id, sources, force)

    source_f = forward_transform(time_domain, self.ft_config)

    if not cache:
        return source_f

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"source_f": source_f},
    )
    return source_f
|
941
|
+
|
942
|
+
def mixture_noise(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Get the post-truth and gained noise audio data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth and gained noise audio data
    """
    from .data_io import write_cached_data
    from .effects import apply_effects

    if not force:
        cached = self.read_mixture_data(m_id, "noise")["noise"]
        if cached is not None:
            return cached

    if sources is None:
        sources = self.mixture_sources(m_id, force)

    # The mixture's noise entry supplies both the effects to apply and the SNR gain
    noise_source = self.mixture(m_id).noise
    post_audio = apply_effects(self, sources["noise"], noise_source.effects, pre=False, post=True)
    noise = post_audio * noise_source.snr_gain

    if not cache:
        return noise

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"noise": noise},
    )
    return noise
|
980
|
+
|
981
|
+
def mixture_noise_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Get the post-truth and gained noise transform for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data (recomputed, and so ignored, when force is True)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth and gained noise transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    # A caller-supplied noise is used only when force is not set
    use_supplied = noise is not None and not force
    if not use_supplied:
        noise = self.mixture_noise(m_id, sources, force)

    noise_f = forward_transform(noise, self.ft_config)

    if not cache:
        return noise_f

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"noise_f": noise_f},
    )
    return noise_f
|
1014
|
+
|
1015
|
+
def mixture_mixture(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Get the mixture audio data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Mixture audio data
    """
    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "mixture")["mixture"]
        if cached is not None:
            return cached

    # Fill in whichever of the two components the caller did not supply (source first, then noise)
    if source is None:
        source = self.mixture_source(m_id, sources, force)
    if noise is None:
        noise = self.mixture_noise(m_id, sources, force)

    # The mixture is the sum of the source and noise audio
    mixture = source + noise

    if not cache:
        return mixture

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"mixture": mixture},
    )
    return mixture
|
1058
|
+
|
1059
|
+
def mixture_mixture_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Get the mixture transform for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Mixture transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform
    from .spectral_mask import apply_spectral_mask

    if mixture is None:
        mixture = self.mixture_mixture(m_id, sources, source, noise, force)

    mixture_f = forward_transform(mixture, self.ft_config)

    # A spectral mask is applied only when the mixture record references one
    record = self.mixture(m_id)
    if record.spectral_mask_id is not None:
        mask = self.spectral_mask(int(record.spectral_mask_id))
        mixture_f = apply_spectral_mask(
            audio_f=mixture_f,
            spectral_mask=mask,
            seed=record.spectral_mask_seed,
        )

    if not cache:
        return mixture_f

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"mixture_f": mixture_f},
    )
    return mixture_f
|
1106
|
+
|
1107
|
+
def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
    """Get the truth_t data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: truth_t data
    """
    from .data_io import write_cached_data
    from .truth import truth_function

    if not force:
        cached = self.read_mixture_data(m_id, "truth_t")["truth_t"]
        if cached is not None:
            return cached

    # No usable cached copy (or force was requested): compute with truth_function
    truth_t = truth_function(self, m_id)

    if not cache:
        return truth_t

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"truth_t": truth_t},
    )
    return truth_t
|
1134
|
+
|
1135
|
+
def mixture_segsnr_t(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> Segsnr:
    """Get the segsnr_t data for the given mixture ID

    Each frame's SNR is the ratio of source energy to noise energy from the forward transform
    (infinity when the noise energy is zero), written to every sample position of that frame.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: segsnr_t data
    :raises ValueError: If the number of energy frames does not match the number of frames in the mixture
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform

    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
        if cached is not None:
            return cached

    if source is None:
        source = self.mixture_source(m_id, sources, force)

    if noise is None:
        noise = self.mixture_noise(m_id, sources, force)

    ft = ForwardTransform(
        length=self.ft_config.length,
        overlap=self.ft_config.overlap,
        bin_start=self.ft_config.bin_start,
        bin_end=self.ft_config.bin_end,
        ttype=self.ft_config.ttype,
    )

    mixture = self.mixture(m_id)

    segsnr_t = np.empty(mixture.samples, dtype=np.float32)

    # Element [1] of execute_all's result is used as the per-frame energy
    source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
    noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()

    # One frame starts every `overlap` samples
    offsets = range(0, mixture.samples, self.ft_config.overlap)
    if len(source_energy) != len(offsets):
        raise ValueError(
            f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(offsets)}"
        )

    for frame, start in enumerate(offsets):
        if noise_energy[frame] == 0:
            # Avoid dividing by zero: a frame with no noise energy has infinite SNR
            frame_snr = np.float32(np.inf)
        else:
            frame_snr = np.float32(source_energy[frame] / noise_energy[frame])

        segsnr_t[start : start + self.ft_config.overlap] = frame_snr

    if not cache:
        return segsnr_t

    write_cached_data(
        location=self.location,
        name="mixture",
        index=mixture.name,
        items={"segsnr_t": segsnr_t},
    )
    return segsnr_t
|
1211
|
+
|
1212
|
+
def mixture_segsnr(
    self,
    m_id: int,
    segsnr_t: Segsnr | None = None,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> Segsnr:
    """Get the segsnr data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param segsnr_t: segsnr_t data
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: segsnr data
    """
    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "segsnr")["segsnr"]
        if cached is not None:
            return cached

    if segsnr_t is None:
        segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)

    # Keep every `overlap`-th value of segsnr_t, starting from index 0
    frame_step = slice(0, None, self.ft_config.overlap)
    segsnr = segsnr_t[frame_step]

    if not cache:
        return segsnr

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"segsnr": segsnr},
    )
    return segsnr
|
1254
|
+
|
1255
|
+
def mixture_ft(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    mixture_f: AudioF | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> tuple[Feature, TruthsDict]:
    """Get the feature and truth_f data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param mixture_f: Mixture transform data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Tuple of (feature, truth_f) data
    :raises TypeError: If the feature generator returns None for truth_f
    """
    from pyaaware import FeatureGenerator

    from .data_io import write_cached_data
    from .truth import truth_stride_reduction

    if not force:
        # Use the cached pair only when both halves are present
        ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
        if ft["feature"] is not None and ft["truth_f"] is not None:
            return ft["feature"], ft["truth_f"]

    if mixture_f is None:
        mixture_f = self.mixture_mixture_f(
            m_id=m_id,
            sources=sources,
            source=source,
            noise=noise,
            mixture=mixture,
            force=force,
        )

    if truth_t is None:
        truth_t = self.mixture_truth_t(m_id, force)

    fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)

    feature, truth_f = fg.execute_all(mixture_f, truth_t)
    if truth_f is None:
        raise TypeError("Unexpected truth of None from feature generator")

    # Apply each truth config's stride reduction, skipping entries whose truth parameters are None
    truth_configs = self.mixture_truth_configs(m_id)
    for category, configs in truth_configs.items():
        for name, config in configs.items():
            if self.truth_parameters[category][name] is not None:
                truth_f[category][name] = truth_stride_reduction(truth_f[category][name], config.stride_reduction)

    if cache:
        # Store the feature under "feature"; the previous code wrote truth_f under both keys,
        # so a later cached read returned truth data in place of the feature.
        write_cached_data(
            location=self.location,
            name="mixture",
            index=self.mixture(m_id).name,
            items={"feature": feature, "truth_f": truth_f},
        )

    return feature, truth_f
|
1324
|
+
|
1325
|
+
def mixture_feature(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> Feature:
    """Get the feature data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Feature data
    """
    from .data_io import write_cached_data

    # mixture_ft produces (feature, truth_f); only the feature half is kept here
    feature_and_truth = self.mixture_ft(
        m_id=m_id,
        sources=sources,
        noise=noise,
        mixture=mixture,
        truth_t=truth_t,
        force=force,
    )
    feature = feature_and_truth[0]

    if not cache:
        return feature

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"feature": feature},
    )
    return feature
|
1366
|
+
|
1367
|
+
def mixture_truth_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> TruthDict:
    """Get the truth_f data for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: truth_f data
    """
    from .data_io import write_cached_data

    # mixture_ft produces (feature, truth_f); only the truth_f half is kept here
    feature_and_truth = self.mixture_ft(
        m_id=m_id,
        sources=sources,
        noise=noise,
        mixture=mixture,
        truth_t=truth_t,
        force=force,
    )
    truth_f = feature_and_truth[1]

    if not cache:
        return truth_f

    write_cached_data(
        location=self.location,
        name="mixture",
        index=self.mixture(m_id).name,
        items={"truth_f": truth_f},
    )
    return truth_f
|
1408
|
+
|
1409
|
+
def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
    """Compute the number of frames for which each class index is active for the given mixture ID

    A class is counted as active in a frame when its "sed" truth value is greater than or equal to
    that class's entry in class_weights_thresholds. Categories without a truth config named "sed"
    report zero for every class.

    :param m_id: Zero-based mixture ID
    :param truth_t: truth_t
    :return: Dictionary of class counts
    """
    import numpy as np

    if truth_t is None:
        truth_t = self.mixture_truth_t(m_id)

    class_count: dict[str, ClassCount] = {}

    truth_configs = self.mixture_truth_configs(m_id)
    for category, configs in truth_configs.items():
        class_count[category] = [0] * self.num_classes

        # Test for a config named exactly "sed". The previous code looped over the config names and
        # used a substring test on each one, so a name merely containing "sed" led to a KeyError below.
        if "sed" not in configs:
            continue

        sed = truth_t[category]["sed"]
        for cl in range(self.num_classes):
            class_count[category][cl] = int(np.sum(sed[:, cl] >= self.class_weights_thresholds[cl]))

    return class_count
|
1434
|
+
|
1435
|
+
@cached_property
def speaker_metadata_tiers(self) -> list[str]:
    """Read the speaker metadata tiers from the top table of the db

    :return: Decoded JSON value of the speaker_metadata_tiers column for the row with id 1
    """
    import json

    with self.db() as cursor:
        row = cursor.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()
        encoded = row[0]
        return json.loads(encoded)
|
1441
|
+
|
1442
|
+
@cached_property
def textgrid_metadata_tiers(self) -> list[str]:
    """Read the TextGrid metadata tiers from the top table of the db

    :return: Decoded JSON value of the textgrid_metadata_tiers column for the row with id 1
    """
    import json

    with self.db() as cursor:
        row = cursor.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()
        encoded = row[0]
        return json.loads(encoded)
|
1448
|
+
|
1449
|
+
@cached_property
def speech_metadata_tiers(self) -> list[str]:
    """Combine the speaker and TextGrid metadata tiers

    :return: Sorted list of the unique tier names from both sets of tiers
    """
    combined = self.speaker_metadata_tiers + self.textgrid_metadata_tiers
    unique_tiers = set(combined)
    return sorted(unique_tiers)
|
1452
|
+
|
1453
|
+
def speaker(self, s_id: int | None, tier: str) -> str | None:
    """Look up a speaker's value for a metadata tier

    :param s_id: Speaker ID (may be None)
    :param tier: Name of the speaker metadata tier
    :return: Result of the module-level _speaker lookup (a string or None)
    """
    # The lookup lives in the module-level helper; forward this instance's db handle and cache setting
    value = _speaker(self.db, s_id, tier, self.use_cache)
    return value
|
1455
|
+
|
1456
|
+
def speech_metadata(self, tier: str) -> list[str]:
    """Collect the unique values of a speech metadata tier across all source files

    TextGrid tiers are read from each source file; speaker tiers are looked up per speaker.
    A tier that is in neither set yields an empty list.

    :param tier: Name of the metadata tier
    :return: Sorted list of unique values found for the tier
    """
    from .helpers import get_textgrid_tier_from_source_file

    found: set[str] = set()

    if tier in self.textgrid_metadata_tiers:
        for files in self.source_files.values():
            for file in files:
                tier_data = get_textgrid_tier_from_source_file(file.name, tier)
                if tier_data is None:
                    continue
                if isinstance(tier_data, list):
                    # A list result contributes the label of each of its entries
                    found.update(entry.label for entry in tier_data)
                else:
                    found.add(tier_data)
    elif tier in self.speaker_metadata_tiers:
        for files in self.source_files.values():
            for file in files:
                value = self.speaker(file.speaker_id, tier)
                if value is not None:
                    found.add(value)

    return sorted(found)
|
1479
|
+
|
1480
|
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
    """Get the speech metadata of a tier for every source in a mixture

    :param mixid: Mixture ID
    :param tier: Name of the metadata tier
    :return: Dictionary of source category: metadata for that source
    """
    from praatio.utilities.constants import Interval

    from .helpers import get_textgrid_tier_from_source_file

    metadata: dict[str, SpeechMetadata] = {}
    is_textgrid = tier in self.textgrid_metadata_tiers

    for category, source in self.mixture(mixid).all_sources.items():
        if not is_textgrid:
            # Speaker tiers come from the speaker of the source's file
            metadata[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
            continue

        tier_data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
        if not isinstance(tier_data, list):
            metadata[category] = tier_data
            continue

        # Check for tempo effect and adjust Interval start and end data as needed
        metadata[category] = [
            Interval(
                entry.start / source.pre_tempo,
                entry.end / source.pre_tempo,
                entry.label,
            )
            for entry in tier_data
        ]

    return metadata
|
1509
|
+
|
1510
|
+
def mixids_for_speech_metadata(
    self,
    tier: str | None = None,
    value: str | None = None,
    where: str | None = None,
) -> dict[str, list[int]]:
    """Get mixture IDs, grouped by source category, for the given speech metadata tier.

    If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
    If 'where' is not None, then ignore 'value' and use the given SQL where clause to determine
    which entries to include.

    Examples:
    >>> mixdb = MixtureDatabase('/mixdb_location')

    >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ABW0')
    Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ABW0'.

    >>> mixids = mixdb.mixids_for_speech_metadata(where='age >= 27')
    Get mixture IDs for mixtures with speakers whose ages are greater than or equal to 27.

    >>> mixids = mixdb.mixids_for_speech_metadata(where="dialect in ('New York City', 'Northern')")
    Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.

    :param tier: Speaker metadata tier (a column of the speaker table); required when 'where' is None
    :param value: Tier value to match; used only when 'where' is None
    :param where: SQL where clause applied to the speaker table; inserted into the query as given
    :return: Dictionary mapping source category to a list of mixture IDs (database ID minus one)
    :raises ValueError: If neither 'value' nor 'where' is given, if 'tier' is required but missing,
        or if 'tier' is a TextGrid tier
    """
    if value is None and where is None:
        raise ValueError("Must provide either value or where")

    params: tuple[str, ...] = ()
    if where is None:
        if tier is None:
            raise ValueError("Must provide tier")
        # Bind the value as a query parameter instead of splicing it into the SQL text, so that
        # values containing quotes (e.g., "O'Brien") neither break nor alter the query.
        where = f"{tier} = ?"
        params = (str(value),)

    if tier is not None and tier in self.textgrid_metadata_tiers:
        raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")

    with self.db() as c:
        speaker_rows = c.execute(f"SELECT id FROM speaker WHERE {where}", params).fetchall()
        speaker_ids = ",".join(str(row[0]) for row in speaker_rows)

        # Group the matching source file IDs by their category
        file_rows = c.execute(
            f"SELECT id, category FROM source_file WHERE speaker_id IN ({speaker_ids})"
        ).fetchall()
        source_file_ids: dict[str, list[int]] = {}
        for source_file_id, category in file_rows:
            source_file_ids.setdefault(category, []).append(source_file_id)

        mixids: dict[str, list[int]] = {}
        for category, file_ids in source_file_ids.items():
            id_str = ",".join(map(str, file_ids))
            source_rows = c.execute(f"SELECT id FROM source WHERE file_id IN ({id_str})").fetchall()
            source_ids = ",".join(str(row[0]) for row in source_rows)

            mixture_rows = c.execute(
                f"SELECT mixture_id FROM mixture_source WHERE source_id IN ({source_ids})"
            ).fetchall()
            # Database mixture IDs are one greater than the mixture IDs returned here
            mixids[category] = [row[0] - 1 for row in mixture_rows]

    return mixids
|
1570
|
+
|
1571
|
+
def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
    """Get all speech metadata for the given mixture.

    Delegates to helpers.mixture_all_speech_metadata() with this database and the mixture record.

    :param m_id: Mixture ID (passed to self.mixture())
    :return: Result of helpers.mixture_all_speech_metadata() for the mixture
    """
    from .helpers import mixture_all_speech_metadata as collect_metadata

    mixture = self.mixture(m_id)
    return collect_metadata(self, mixture)
|
1575
|
+
|
1576
|
+
def cached_metrics(self, m_ids: GeneralizedIDs = "*") -> list[str]:
    """Get the sorted names of the metrics that are cached for every one of the given mixtures.

    A metric counts as cached for a mixture when a '<metric>.pkl' file exists in that mixture's
    directory. Only names listed in self.supported_metrics.names are reported.

    :param m_ids: Mixture IDs (expanded with self.mixids_to_list())
    :return: Sorted list of metric names found in all of the given mixtures
    """
    from glob import glob
    from os.path import join
    from pathlib import Path

    supported = self.supported_metrics.names
    common: set[str] | None = None
    for m_id in self.mixids_to_list(m_ids):
        pattern = join(join(self.location, "mixture", self.mixture(m_id).name), "*.pkl")
        stems = {Path(path).stem for path in glob(pattern)}
        if common is None:
            # First mixture: start from its supported metrics
            common = {stem for stem in stems if stem in supported}
        else:
            # Later mixtures: keep only what is also present here
            common = common & stems

    return sorted(common) if common is not None else []
|
1597
|
+
|
1598
|
+
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Get metrics data for the given mixture ID

    Delegates to metrics.calculate_metrics() with this database.

    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data
    """
    from ..metrics import calculate_metrics as compute_metrics

    return compute_metrics(self, m_id, metrics, force)
|
1609
|
+
|
1610
|
+
|
1611
|
+
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
    """Get spectral mask with ID from db

    :param db: Database context
    :param sm_id: Spectral mask ID
    :param use_cache: If true, use LRU caching
    :return: Spectral mask
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __spectral_mask if use_cache else __spectral_mask.__wrapped__
    return fetch(db, sm_id)
|
1622
|
+
|
1623
|
+
|
1624
|
+
@lru_cache
def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
    """Read the spectral mask with the given ID from db (LRU cached)

    :param db: Database context
    :param sm_id: Spectral mask ID; matched directly against the id column of spectral_mask
    :return: Spectral mask
    """
    from .db_datatypes import SpectralMaskRecord

    with db() as cursor:
        row = cursor.execute("SELECT * FROM spectral_mask WHERE ? = id", (sm_id,)).fetchone()
        record = SpectralMaskRecord(*row)
        return SpectralMask(
            f_max_width=record.f_max_width,
            f_num=record.f_num,
            t_max_width=record.t_max_width,
            t_num=record.t_num,
            t_max_percent=record.t_max_percent,
        )
|
1637
|
+
|
1638
|
+
|
1639
|
+
def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
    """Get the number of source files from a category from db

    :param db: Database context
    :param category: Source category
    :param use_cache: If true, use LRU caching
    :return: Number of source files
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    count = __num_source_files if use_cache else __num_source_files.__wrapped__
    return count(db, category)
|
1650
|
+
|
1651
|
+
|
1652
|
+
@lru_cache
def __num_source_files(db: partial, category: str) -> int:
    """Count the source_file rows that belong to a category (LRU cached)

    :param db: Database context
    :param category: Source category
    :return: Number of source files
    """
    with db() as cursor:
        row = cursor.execute("SELECT count(id) FROM source_file WHERE ? = category", (category,)).fetchone()
        return int(row[0])
|
1662
|
+
|
1663
|
+
|
1664
|
+
def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
    """Get the source file with ID from db

    :param db: Database context
    :param s_id: Source file ID
    :param use_cache: If true, use LRU caching
    :return: Source file
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache.
    # use_cache is forwarded so that the nested truth config lookup follows the same policy.
    fetch = __source_file if use_cache else __source_file.__wrapped__
    return fetch(db, s_id, use_cache)
|
1675
|
+
|
1676
|
+
|
1677
|
+
@lru_cache
def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
    """Read the source file with the given ID from db (LRU cached)

    :param db: Database context
    :param s_id: Source file ID; matched directly against the id column of source_file
    :param use_cache: If true, use LRU caching when reading the truth configs
    :return: Source file
    """
    import json

    from .db_datatypes import SourceFileRecord

    with db() as cursor:
        row = cursor.execute("SELECT * FROM source_file WHERE ? = id", (s_id,)).fetchone()
        record = SourceFileRecord(*row)

    # Truth configs are read after the database context above has been exited
    return SourceFile(
        category=record.category,
        name=record.name,
        samples=record.samples,
        class_indices=json.loads(record.class_indices),
        level_type=record.level_type,
        truth_configs=_source_truth_configs(db, s_id, use_cache),
        speaker_id=record.speaker_id,
    )
|
1702
|
+
|
1703
|
+
|
1704
|
+
def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
    """Get impulse response file name with ID from db

    :param db: Database context
    :param ir_id: Impulse response file ID
    :param use_cache: If true, use LRU caching
    :return: Impulse response file name
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __ir_file if use_cache else __ir_file.__wrapped__
    return fetch(db, ir_id)
|
1715
|
+
|
1716
|
+
|
1717
|
+
@lru_cache
def __ir_file(db: partial, ir_id: int) -> str:
    """Read the name of an impulse response file from db (LRU cached)

    :param db: Database context
    :param ir_id: Impulse response file ID; the ir_file row queried is the one whose id is ir_id + 1
    :return: Impulse response file name
    """
    row_id = ir_id + 1
    with db() as cursor:
        row = cursor.execute("SELECT name FROM ir_file WHERE ? = id ", (row_id,)).fetchone()
        return str(row[0])
|
1721
|
+
|
1722
|
+
|
1723
|
+
def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
    """Get impulse response delay with ID from db

    :param db: Database context
    :param ir_id: Impulse response file ID
    :param use_cache: If true, use LRU caching
    :return: Impulse response delay
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __ir_delay if use_cache else __ir_delay.__wrapped__
    return fetch(db, ir_id)
|
1734
|
+
|
1735
|
+
|
1736
|
+
@lru_cache
def __ir_delay(db: partial, ir_id: int) -> int:
    """Read the delay of an impulse response file from db (LRU cached)

    :param db: Database context
    :param ir_id: Impulse response file ID; the ir_file row queried is the one whose id is ir_id + 1
    :return: Impulse response delay
    """
    row_id = ir_id + 1
    with db() as cursor:
        row = cursor.execute("SELECT delay FROM ir_file WHERE ? = id", (row_id,)).fetchone()
        return int(row[0])
|
1740
|
+
|
1741
|
+
|
1742
|
+
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
    """Get mixture record with ID from db

    :param db: Database context
    :param m_id: Zero-based mixture ID
    :param use_cache: If true, use LRU caching
    :return: Mixture record
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __mixture if use_cache else __mixture.__wrapped__
    return fetch(db, m_id)
|
1753
|
+
|
1754
|
+
|
1755
|
+
@lru_cache
def __mixture(db: partial, m_id: int) -> Mixture:
    """Read a mixture record and its sources from db (LRU cached)

    :param db: Database context
    :param m_id: Zero-based mixture ID; the mixture row queried is the one whose id is m_id + 1
    :return: Mixture record with its sources keyed by source file category
    """
    from .db_datatypes import MixtureRecord
    from .db_datatypes import SourceRecord
    from .helpers import to_mixture
    from .helpers import to_source

    with db() as cursor:
        mixture_row = cursor.execute("SELECT * FROM mixture WHERE ? = id", (m_id + 1,)).fetchone()
        record = MixtureRecord(*mixture_row)

        # Fetch all source rows first; the cursor is reused for the category lookups below
        source_rows = cursor.execute(
            """
            SELECT source.*
            FROM source, mixture_source
            WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
            """,
            (record.id,),
        ).fetchall()

        sources: Sources = {}
        for source_row in source_rows:
            source_record = SourceRecord(*source_row)
            category_row = cursor.execute(
                "SELECT category FROM source_file WHERE ? = id", (source_record.file_id,)
            ).fetchone()
            sources[category_row[0]] = to_source(source_record)

        return to_mixture(record, sources)
|
1779
|
+
|
1780
|
+
|
1781
|
+
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
    """Get the value of a speaker metadata tier for the speaker with ID from db

    :param db: Database context
    :param s_id: Speaker ID (may be None)
    :param tier: Speaker metadata tier (column of the speaker table)
    :param use_cache: If true, use LRU caching
    :return: Tier value, or None if there is none
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __speaker if use_cache else __speaker.__wrapped__
    return fetch(db, s_id, tier)
|
1785
|
+
|
1786
|
+
|
1787
|
+
@lru_cache
def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
    """Read one speaker metadata tier value from db (LRU cached)

    :param db: Database context
    :param s_id: Speaker ID; None yields None without querying the database
    :param tier: Speaker metadata tier; inserted into the query as the column name to select
    :return: Tier value, or None if the speaker is not found or the value is NULL
    """
    if s_id is None:
        return None

    with db() as cursor:
        row = cursor.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
        # A missing row and a NULL column value both yield None
        return None if row is None else row[0]
|
1799
|
+
|
1800
|
+
|
1801
|
+
def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
    """Get the truth config names and functions for a source category from db

    :param db: Database context
    :param category: Source category
    :param use_cache: If true, use LRU caching
    :return: Dictionary mapping truth config name to truth function name
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __category_truth_configs if use_cache else __category_truth_configs.__wrapped__
    return fetch(db, category)
|
1805
|
+
|
1806
|
+
|
1807
|
+
@lru_cache
def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
    """Collect the truth configs used by the source files of a category (LRU cached)

    :param db: Database context
    :param category: Source category
    :return: Dictionary mapping truth config name to truth function name
    """
    import json

    configs: dict[str, str] = {}
    with db() as cursor:
        file_rows = cursor.execute("SELECT id FROM source_file WHERE ? = category", (category,)).fetchall()

        for file_row in file_rows:
            config_rows = cursor.execute(
                """
                SELECT truth_config.config
                FROM truth_config, source_file_truth_config
                WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
                """,
                (file_row[0],),
            ).fetchall()
            for config_row in config_rows:
                parsed = json.loads(config_row[0])
                # Later entries with the same name replace earlier ones
                configs[parsed["name"]] = parsed["function"]
        return configs
|
1827
|
+
|
1828
|
+
|
1829
|
+
def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
    """Get the truth configs for the source file with ID from db

    :param db: Database context
    :param s_id: Source file ID
    :param use_cache: If true, use LRU caching
    :return: Truth configs keyed by name
    """
    # __wrapped__ is the undecorated function, so calling it bypasses the LRU cache
    fetch = __source_truth_configs if use_cache else __source_truth_configs.__wrapped__
    return fetch(db, s_id)
|
1833
|
+
|
1834
|
+
|
1835
|
+
@lru_cache
def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
    """Read the truth configs linked to a source file from db (LRU cached)

    :param db: Database context
    :param s_id: Source file ID; matched against source_file_truth_config.source_file_id
    :return: Truth configs keyed by name
    """
    import json

    from ..datatypes import TruthConfig

    configs: TruthConfigs = {}
    with db() as cursor:
        config_rows = cursor.execute(
            """
            SELECT truth_config.config
            FROM truth_config, source_file_truth_config
            WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
            """,
            (s_id,),
        ).fetchall()
        for config_row in config_rows:
            # Each stored config is a JSON object
            parsed = json.loads(config_row[0])
            configs[parsed["name"]] = TruthConfig(
                function=parsed["function"],
                stride_reduction=parsed["stride_reduction"],
                config=parsed["config"],
            )
        return configs
|