sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +50 -46
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +677 -473
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.8.dist-info/RECORD +0 -125
- {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -1,46 +1,47 @@
|
|
1
|
+
# ruff: noqa: S608
|
1
2
|
from functools import cached_property
|
2
3
|
from functools import lru_cache
|
3
4
|
from functools import partial
|
4
5
|
from sqlite3 import Connection
|
5
6
|
from sqlite3 import Cursor
|
6
7
|
from typing import Any
|
7
|
-
|
8
|
-
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
25
|
-
from
|
26
|
-
from
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from
|
31
|
-
from
|
32
|
-
from
|
33
|
-
from
|
34
|
-
from
|
8
|
+
|
9
|
+
from .datatypes import ASRConfigs
|
10
|
+
from .datatypes import AudioF
|
11
|
+
from .datatypes import AudiosF
|
12
|
+
from .datatypes import AudiosT
|
13
|
+
from .datatypes import AudioT
|
14
|
+
from .datatypes import ClassCount
|
15
|
+
from .datatypes import Feature
|
16
|
+
from .datatypes import FeatureGeneratorConfig
|
17
|
+
from .datatypes import FeatureGeneratorInfo
|
18
|
+
from .datatypes import GeneralizedIDs
|
19
|
+
from .datatypes import ImpulseResponseFiles
|
20
|
+
from .datatypes import MetricDoc
|
21
|
+
from .datatypes import MetricDocs
|
22
|
+
from .datatypes import Mixture
|
23
|
+
from .datatypes import Mixtures
|
24
|
+
from .datatypes import NoiseFile
|
25
|
+
from .datatypes import NoiseFiles
|
26
|
+
from .datatypes import Segsnr
|
27
|
+
from .datatypes import SpectralMask
|
28
|
+
from .datatypes import SpectralMasks
|
29
|
+
from .datatypes import SpeechMetadata
|
30
|
+
from .datatypes import TargetFile
|
31
|
+
from .datatypes import TargetFiles
|
32
|
+
from .datatypes import TransformConfig
|
33
|
+
from .datatypes import TruthConfigs
|
34
|
+
from .datatypes import TruthDict
|
35
|
+
from .datatypes import UniversalSNR
|
35
36
|
|
36
37
|
|
37
38
|
def db_file(location: str, test: bool = False) -> str:
|
38
39
|
from os.path import join
|
39
40
|
|
40
41
|
if test:
|
41
|
-
name =
|
42
|
+
name = "mixdb_test.db"
|
42
43
|
else:
|
43
|
-
name =
|
44
|
+
name = "mixdb.db"
|
44
45
|
|
45
46
|
return join(location, name)
|
46
47
|
|
@@ -50,19 +51,17 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
|
|
50
51
|
from os import remove
|
51
52
|
from os.path import exists
|
52
53
|
|
53
|
-
from sonusai import SonusAIError
|
54
|
-
|
55
54
|
name = db_file(location, test)
|
56
55
|
if create and exists(name):
|
57
56
|
remove(name)
|
58
57
|
|
59
58
|
if not create and not exists(name):
|
60
|
-
raise
|
59
|
+
raise OSError(f"Could not find mixture database in {location}")
|
61
60
|
|
62
61
|
if not create and readonly:
|
63
|
-
name +=
|
62
|
+
name += "?mode=ro"
|
64
63
|
|
65
|
-
connection = sqlite3.connect(
|
64
|
+
connection = sqlite3.connect("file:" + name, uri=True)
|
66
65
|
# connection.set_trace_callback(print)
|
67
66
|
return connection
|
68
67
|
|
@@ -103,25 +102,23 @@ class MixtureDatabase:
|
|
103
102
|
num_classes=self.num_classes,
|
104
103
|
spectral_masks=self.spectral_masks,
|
105
104
|
target_files=self.target_files,
|
106
|
-
truth_mutex=self.truth_mutex,
|
107
|
-
truth_reduction_function=self.truth_reduction_function
|
108
105
|
)
|
109
106
|
return config.to_json(indent=2)
|
110
107
|
|
111
108
|
def save(self) -> None:
|
112
|
-
"""Save the MixtureDatabase as a JSON file
|
113
|
-
"""
|
109
|
+
"""Save the MixtureDatabase as a JSON file"""
|
114
110
|
from os.path import join
|
115
111
|
|
116
|
-
json_name = join(self.location,
|
117
|
-
with open(file=json_name, mode=
|
112
|
+
json_name = join(self.location, "mixdb.json")
|
113
|
+
with open(file=json_name, mode="w") as file:
|
118
114
|
file.write(self.json)
|
119
115
|
|
120
116
|
@cached_property
|
121
117
|
def fg_config(self) -> FeatureGeneratorConfig:
|
122
|
-
return FeatureGeneratorConfig(
|
123
|
-
|
124
|
-
|
118
|
+
return FeatureGeneratorConfig(
|
119
|
+
feature_mode=self.feature,
|
120
|
+
truth_parameters=self.truth_parameters,
|
121
|
+
)
|
125
122
|
|
126
123
|
@cached_property
|
127
124
|
def fg_info(self) -> FeatureGeneratorInfo:
|
@@ -130,19 +127,18 @@ class MixtureDatabase:
|
|
130
127
|
return get_feature_generator_info(self.fg_config)
|
131
128
|
|
132
129
|
@cached_property
|
133
|
-
def
|
134
|
-
with self.db() as c:
|
135
|
-
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
136
|
-
|
137
|
-
@cached_property
|
138
|
-
def truth_mutex(self) -> bool:
|
130
|
+
def truth_parameters(self) -> dict[str, int]:
|
139
131
|
with self.db() as c:
|
140
|
-
|
132
|
+
rows = c.execute("SELECT * FROM truth_parameters").fetchall()
|
133
|
+
truth_parameters: dict[str, int] = {}
|
134
|
+
for row in rows:
|
135
|
+
truth_parameters[row[1]] = row[2]
|
136
|
+
return truth_parameters
|
141
137
|
|
142
138
|
@cached_property
|
143
|
-
def
|
139
|
+
def num_classes(self) -> int:
|
144
140
|
with self.db() as c:
|
145
|
-
return
|
141
|
+
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
146
142
|
|
147
143
|
@cached_property
|
148
144
|
def noise_mix_mode(self) -> str:
|
@@ -158,68 +154,152 @@ class MixtureDatabase:
|
|
158
154
|
|
159
155
|
@cached_property
|
160
156
|
def supported_metrics(self) -> MetricDocs:
|
161
|
-
metrics = MetricDocs(
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
157
|
+
metrics = MetricDocs(
|
158
|
+
[
|
159
|
+
MetricDoc("Mixture Metrics", "mxsnr", "SNR specification in dB"),
|
160
|
+
MetricDoc(
|
161
|
+
"Mixture Metrics",
|
162
|
+
"mxssnr_avg",
|
163
|
+
"Segmental SNR average over all frames",
|
164
|
+
),
|
165
|
+
MetricDoc(
|
166
|
+
"Mixture Metrics",
|
167
|
+
"mxssnr_std",
|
168
|
+
"Segmental SNR standard deviation over all frames",
|
169
|
+
),
|
170
|
+
MetricDoc(
|
171
|
+
"Mixture Metrics",
|
172
|
+
"mxssnrdb_avg",
|
173
|
+
"Segmental SNR average of the dB frame values over all frames",
|
174
|
+
),
|
175
|
+
MetricDoc(
|
176
|
+
"Mixture Metrics",
|
177
|
+
"mxssnrdb_std",
|
178
|
+
"Segmental SNR standard deviation of the dB frame values over all frames",
|
179
|
+
),
|
180
|
+
MetricDoc(
|
181
|
+
"Mixture Metrics",
|
182
|
+
"mxssnrf_avg",
|
183
|
+
"Per-bin segmental SNR average over all frames (using feature transform)",
|
184
|
+
),
|
185
|
+
MetricDoc(
|
186
|
+
"Mixture Metrics",
|
187
|
+
"mxssnrf_std",
|
188
|
+
"Per-bin segmental SNR standard deviation over all frames (using feature transform)",
|
189
|
+
),
|
190
|
+
MetricDoc(
|
191
|
+
"Mixture Metrics",
|
192
|
+
"mxssnrdbf_avg",
|
193
|
+
"Per-bin segmental average of the dB frame values over all frames (using feature transform)",
|
194
|
+
),
|
195
|
+
MetricDoc(
|
196
|
+
"Mixture Metrics",
|
197
|
+
"mxssnrdbf_std",
|
198
|
+
"Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
|
199
|
+
),
|
200
|
+
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true target[0]"),
|
201
|
+
MetricDoc(
|
202
|
+
"Mixture Metrics",
|
203
|
+
"mxwsdr",
|
204
|
+
"Weighted signal distorion ratio of mixture versus true target[0]",
|
205
|
+
),
|
206
|
+
MetricDoc(
|
207
|
+
"Mixture Metrics",
|
208
|
+
"mxpd",
|
209
|
+
"Phase distance between mixture and true target[0]",
|
210
|
+
),
|
211
|
+
MetricDoc(
|
212
|
+
"Mixture Metrics",
|
213
|
+
"mxstoi",
|
214
|
+
"Short term objective intelligibility of mixture versus true target[0]",
|
215
|
+
),
|
216
|
+
MetricDoc(
|
217
|
+
"Mixture Metrics",
|
218
|
+
"mxcsig",
|
219
|
+
"Predicted rating of speech distortion of mixture versus true target[0]",
|
220
|
+
),
|
221
|
+
MetricDoc(
|
222
|
+
"Mixture Metrics",
|
223
|
+
"mxcbak",
|
224
|
+
"Predicted rating of background distortion of mixture versus true target[0]",
|
225
|
+
),
|
226
|
+
MetricDoc(
|
227
|
+
"Mixture Metrics",
|
228
|
+
"mxcovl",
|
229
|
+
"Predicted rating of overall quality of mixture versus true target[0]",
|
230
|
+
),
|
231
|
+
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
232
|
+
MetricDoc("Target Metrics", "tdco", "Target[0] DC offset"),
|
233
|
+
MetricDoc("Target Metrics", "tmin", "Target[0] min level"),
|
234
|
+
MetricDoc("Target Metrics", "tmax", "Target[0] max levl"),
|
235
|
+
MetricDoc("Target Metrics", "tpkdb", "Target[0] Pk lev dB"),
|
236
|
+
MetricDoc("Target Metrics", "tlrms", "Target[0] RMS lev dB"),
|
237
|
+
MetricDoc("Target Metrics", "tpkr", "Target[0] RMS Pk dB"),
|
238
|
+
MetricDoc("Target Metrics", "ttr", "Target[0] RMS Tr dB"),
|
239
|
+
MetricDoc("Target Metrics", "tcr", "Target[0] Crest factor"),
|
240
|
+
MetricDoc("Target Metrics", "tfl", "Target[0] Flat factor"),
|
241
|
+
MetricDoc("Target Metrics", "tpkc", "Target[0] Pk count"),
|
242
|
+
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
243
|
+
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
244
|
+
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
245
|
+
MetricDoc("Noise Metrics", "npkdb", "Noise Pk lev dB"),
|
246
|
+
MetricDoc("Noise Metrics", "nlrms", "Noise RMS lev dB"),
|
247
|
+
MetricDoc("Noise Metrics", "npkr", "Noise RMS Pk dB"),
|
248
|
+
MetricDoc("Noise Metrics", "ntr", "Noise RMS Tr dB"),
|
249
|
+
MetricDoc("Noise Metrics", "ncr", "Noise Crest factor"),
|
250
|
+
MetricDoc("Noise Metrics", "nfl", "Noise Flat factor"),
|
251
|
+
MetricDoc("Noise Metrics", "npkc", "Noise Pk count"),
|
252
|
+
MetricDoc(
|
253
|
+
"Truth Metrics",
|
254
|
+
"sedavg",
|
255
|
+
"(not implemented) Average SED activity over all frames [truth_parameters, 1]",
|
256
|
+
),
|
257
|
+
MetricDoc(
|
258
|
+
"Truth Metrics",
|
259
|
+
"sedcnt",
|
260
|
+
"(not implemented) Count in number of frames that SED is active [truth_parameters, 1]",
|
261
|
+
),
|
262
|
+
MetricDoc(
|
263
|
+
"Truth Metrics",
|
264
|
+
"sedtop3",
|
265
|
+
"(not implemented) 3 most active by largest sedavg [3, 1]",
|
266
|
+
),
|
267
|
+
MetricDoc(
|
268
|
+
"Truth Metrics",
|
269
|
+
"sedtopn",
|
270
|
+
"(not implemented) N most active by largest sedavg [N, 1]",
|
271
|
+
),
|
272
|
+
]
|
273
|
+
)
|
216
274
|
for name in self.asr_configs:
|
217
|
-
metrics.append(
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
275
|
+
metrics.append(
|
276
|
+
MetricDoc(
|
277
|
+
"Target Metrics",
|
278
|
+
f"tasr.{name}",
|
279
|
+
f"Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
280
|
+
)
|
281
|
+
)
|
282
|
+
metrics.append(
|
283
|
+
MetricDoc(
|
284
|
+
"Mixture Metrics",
|
285
|
+
f"mxasr.{name}",
|
286
|
+
f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
287
|
+
)
|
288
|
+
)
|
289
|
+
metrics.append(
|
290
|
+
MetricDoc(
|
291
|
+
"Target Metrics",
|
292
|
+
f"basewer.{name}",
|
293
|
+
f"Word error rate of tasr.{name} vs. speech text metadata for the target",
|
294
|
+
)
|
295
|
+
)
|
296
|
+
metrics.append(
|
297
|
+
MetricDoc(
|
298
|
+
"Mixture Metrics",
|
299
|
+
f"mxwer.{name}",
|
300
|
+
f"Word error rate of mxasr.{name} vs. tasr.{name}",
|
301
|
+
)
|
302
|
+
)
|
223
303
|
|
224
304
|
return metrics
|
225
305
|
|
@@ -265,7 +345,7 @@ class MixtureDatabase:
|
|
265
345
|
def transform_frame_ms(self) -> float:
|
266
346
|
from .constants import SAMPLE_RATE
|
267
347
|
|
268
|
-
return float(self.ft_config.
|
348
|
+
return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
|
269
349
|
|
270
350
|
@cached_property
|
271
351
|
def feature_ms(self) -> float:
|
@@ -273,7 +353,7 @@ class MixtureDatabase:
|
|
273
353
|
|
274
354
|
@cached_property
|
275
355
|
def feature_samples(self) -> int:
|
276
|
-
return self.ft_config.
|
356
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_stride
|
277
357
|
|
278
358
|
@cached_property
|
279
359
|
def feature_step_ms(self) -> float:
|
@@ -281,28 +361,33 @@ class MixtureDatabase:
|
|
281
361
|
|
282
362
|
@cached_property
|
283
363
|
def feature_step_samples(self) -> int:
|
284
|
-
return self.ft_config.
|
364
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_step
|
285
365
|
|
286
|
-
def total_samples(self, m_ids: GeneralizedIDs =
|
287
|
-
|
366
|
+
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
|
367
|
+
samples = 0
|
368
|
+
for m_id in self.mixids_to_list(m_ids):
|
369
|
+
s = self.mixture(m_id).samples
|
370
|
+
if s is not None:
|
371
|
+
samples += s
|
372
|
+
return samples
|
288
373
|
|
289
|
-
def total_transform_frames(self, m_ids: GeneralizedIDs =
|
290
|
-
return self.total_samples(m_ids) // self.ft_config.
|
374
|
+
def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
375
|
+
return self.total_samples(m_ids) // self.ft_config.overlap
|
291
376
|
|
292
|
-
def total_feature_frames(self, m_ids: GeneralizedIDs =
|
377
|
+
def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
293
378
|
return self.total_samples(m_ids) // self.feature_step_samples
|
294
379
|
|
295
380
|
def mixture_transform_frames(self, m_id: int) -> int:
|
296
381
|
from .helpers import frames_from_samples
|
297
382
|
|
298
|
-
return frames_from_samples(self.mixture(m_id).samples, self.ft_config.
|
383
|
+
return frames_from_samples(self.mixture(m_id).samples, self.ft_config.overlap)
|
299
384
|
|
300
385
|
def mixture_feature_frames(self, m_id: int) -> int:
|
301
386
|
from .helpers import frames_from_samples
|
302
387
|
|
303
388
|
return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
|
304
389
|
|
305
|
-
def mixids_to_list(self, m_ids:
|
390
|
+
def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
|
306
391
|
"""Resolve generalized mixture IDs to a list of integers
|
307
392
|
|
308
393
|
:param m_ids: Generalized mixture IDs
|
@@ -319,8 +404,10 @@ class MixtureDatabase:
|
|
319
404
|
:return: Class labels
|
320
405
|
"""
|
321
406
|
with self.db() as c:
|
322
|
-
return [
|
323
|
-
|
407
|
+
return [
|
408
|
+
str(item[0])
|
409
|
+
for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
|
410
|
+
]
|
324
411
|
|
325
412
|
@cached_property
|
326
413
|
def class_weights_thresholds(self) -> list[float]:
|
@@ -329,8 +416,37 @@ class MixtureDatabase:
|
|
329
416
|
:return: Class weights thresholds
|
330
417
|
"""
|
331
418
|
with self.db() as c:
|
332
|
-
return [
|
333
|
-
|
419
|
+
return [
|
420
|
+
float(item[0])
|
421
|
+
for item in c.execute(
|
422
|
+
"SELECT class_weights_threshold.threshold FROM class_weights_threshold"
|
423
|
+
).fetchall()
|
424
|
+
]
|
425
|
+
|
426
|
+
@cached_property
|
427
|
+
def truth_configs(self) -> TruthConfigs:
|
428
|
+
"""Get truth configs from db
|
429
|
+
|
430
|
+
:return: Truth configs
|
431
|
+
"""
|
432
|
+
import json
|
433
|
+
|
434
|
+
from .datatypes import TruthConfig
|
435
|
+
|
436
|
+
with self.db() as c:
|
437
|
+
truth_configs: TruthConfigs = {}
|
438
|
+
for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
|
439
|
+
truth_config = json.loads(truth_config_record[0])
|
440
|
+
if truth_config["name"] not in truth_configs:
|
441
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
442
|
+
function=truth_config["function"],
|
443
|
+
stride_reduction=truth_config["stride_reduction"],
|
444
|
+
config=truth_config["config"],
|
445
|
+
)
|
446
|
+
return truth_configs
|
447
|
+
|
448
|
+
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
449
|
+
return _target_truth_configs(self.db, t_id)
|
334
450
|
|
335
451
|
@cached_property
|
336
452
|
def random_snrs(self) -> list[float]:
|
@@ -339,8 +455,12 @@ class MixtureDatabase:
|
|
339
455
|
:return: Random SNRs
|
340
456
|
"""
|
341
457
|
with self.db() as c:
|
342
|
-
return list(
|
343
|
-
|
458
|
+
return list(
|
459
|
+
{
|
460
|
+
float(item[0])
|
461
|
+
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
|
462
|
+
}
|
463
|
+
)
|
344
464
|
|
345
465
|
@cached_property
|
346
466
|
def snrs(self) -> list[float]:
|
@@ -349,13 +469,21 @@ class MixtureDatabase:
|
|
349
469
|
:return: SNRs
|
350
470
|
"""
|
351
471
|
with self.db() as c:
|
352
|
-
return list(
|
353
|
-
|
472
|
+
return list(
|
473
|
+
{
|
474
|
+
float(item[0])
|
475
|
+
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
|
476
|
+
}
|
477
|
+
)
|
354
478
|
|
355
479
|
@cached_property
|
356
480
|
def all_snrs(self) -> list[UniversalSNR]:
|
357
|
-
return sorted(
|
358
|
-
|
481
|
+
return sorted(
|
482
|
+
set(
|
483
|
+
[UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
|
484
|
+
+ [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
|
485
|
+
)
|
486
|
+
)
|
359
487
|
|
360
488
|
@cached_property
|
361
489
|
def spectral_masks(self) -> SpectralMasks:
|
@@ -366,13 +494,19 @@ class MixtureDatabase:
|
|
366
494
|
from .db_datatypes import SpectralMaskRecord
|
367
495
|
|
368
496
|
with self.db() as c:
|
369
|
-
spectral_masks = [
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
497
|
+
spectral_masks = [
|
498
|
+
SpectralMaskRecord(*result) for result in c.execute("SELECT * FROM spectral_mask").fetchall()
|
499
|
+
]
|
500
|
+
return [
|
501
|
+
SpectralMask(
|
502
|
+
f_max_width=spectral_mask.f_max_width,
|
503
|
+
f_num=spectral_mask.f_num,
|
504
|
+
t_max_width=spectral_mask.t_max_width,
|
505
|
+
t_num=spectral_mask.t_num,
|
506
|
+
t_max_percent=spectral_mask.t_max_percent,
|
507
|
+
)
|
508
|
+
for spectral_mask in spectral_masks
|
509
|
+
]
|
376
510
|
|
377
511
|
def spectral_mask(self, sm_id: int) -> SpectralMask:
|
378
512
|
"""Get spectral mask with ID from db
|
@@ -390,31 +524,40 @@ class MixtureDatabase:
|
|
390
524
|
"""
|
391
525
|
import json
|
392
526
|
|
393
|
-
from .datatypes import
|
394
|
-
from .datatypes import
|
527
|
+
from .datatypes import TruthConfig
|
528
|
+
from .datatypes import TruthConfigs
|
395
529
|
from .db_datatypes import TargetFileRecord
|
396
530
|
|
397
531
|
with self.db() as c:
|
398
532
|
target_files: TargetFiles = []
|
399
|
-
target_file_records = [
|
400
|
-
|
533
|
+
target_file_records = [
|
534
|
+
TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
|
535
|
+
]
|
401
536
|
for target_file_record in target_file_records:
|
402
|
-
|
403
|
-
for
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
537
|
+
truth_configs: TruthConfigs = {}
|
538
|
+
for truth_config_records in c.execute(
|
539
|
+
"SELECT truth_config.config "
|
540
|
+
+ "FROM truth_config, target_file_truth_config "
|
541
|
+
+ "WHERE ? = target_file_truth_config.target_file_id "
|
542
|
+
+ "AND truth_config.id = target_file_truth_config.truth_config_id",
|
543
|
+
(target_file_record.id,),
|
544
|
+
).fetchall():
|
545
|
+
truth_config = json.loads(truth_config_records[0])
|
546
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
547
|
+
function=truth_config["function"],
|
548
|
+
stride_reduction=truth_config["stride_reduction"],
|
549
|
+
config=truth_config["config"],
|
550
|
+
)
|
551
|
+
target_files.append(
|
552
|
+
TargetFile(
|
553
|
+
name=target_file_record.name,
|
554
|
+
samples=target_file_record.samples,
|
555
|
+
class_indices=json.loads(target_file_record.class_indices),
|
556
|
+
level_type=target_file_record.level_type,
|
557
|
+
truth_configs=truth_configs,
|
558
|
+
speaker_id=target_file_record.speaker_id,
|
559
|
+
)
|
560
|
+
)
|
418
561
|
return target_files
|
419
562
|
|
420
563
|
@cached_property
|
@@ -450,8 +593,10 @@ class MixtureDatabase:
|
|
450
593
|
:return: Noise files
|
451
594
|
"""
|
452
595
|
with self.db() as c:
|
453
|
-
return [
|
454
|
-
|
596
|
+
return [
|
597
|
+
NoiseFile(name=noise[0], samples=noise[1])
|
598
|
+
for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
|
599
|
+
]
|
455
600
|
|
456
601
|
@cached_property
|
457
602
|
def noise_file_ids(self) -> list[int]:
|
@@ -485,9 +630,21 @@ class MixtureDatabase:
|
|
485
630
|
|
486
631
|
:return: Impulse response files
|
487
632
|
"""
|
633
|
+
import json
|
634
|
+
|
635
|
+
from .datatypes import ImpulseResponseFile
|
636
|
+
|
488
637
|
with self.db() as c:
|
489
|
-
|
490
|
-
|
638
|
+
# for impulse_response in c.execute(
|
639
|
+
# "SELECT impulse_response_file.* FROM impulse_response_file"
|
640
|
+
# ).fetchall():
|
641
|
+
# print(impulse_response)
|
642
|
+
return [
|
643
|
+
ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
|
644
|
+
for impulse_response in c.execute(
|
645
|
+
"SELECT impulse_response_file.* FROM impulse_response_file"
|
646
|
+
).fetchall()
|
647
|
+
]
|
491
648
|
|
492
649
|
@cached_property
|
493
650
|
def impulse_response_file_ids(self) -> list[int]:
|
@@ -496,15 +653,19 @@ class MixtureDatabase:
|
|
496
653
|
:return: List of impulse response file IDs
|
497
654
|
"""
|
498
655
|
with self.db() as c:
|
499
|
-
return [
|
500
|
-
|
656
|
+
return [
|
657
|
+
int(item[0])
|
658
|
+
for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
|
659
|
+
]
|
501
660
|
|
502
|
-
def impulse_response_file(self, ir_id: int) -> str:
|
661
|
+
def impulse_response_file(self, ir_id: int | None) -> str | None:
|
503
662
|
"""Get impulse response file with ID from db
|
504
663
|
|
505
664
|
:param ir_id: Impulse response file ID
|
506
665
|
:return: Noise
|
507
666
|
"""
|
667
|
+
if ir_id is None:
|
668
|
+
return None
|
508
669
|
return _impulse_response_file(self.db, ir_id)
|
509
670
|
|
510
671
|
@cached_property
|
@@ -522,18 +683,22 @@ class MixtureDatabase:
|
|
522
683
|
|
523
684
|
:return: Mixtures
|
524
685
|
"""
|
525
|
-
from .helpers import to_mixture
|
526
|
-
from .helpers import to_target
|
527
686
|
from .db_datatypes import MixtureRecord
|
528
687
|
from .db_datatypes import TargetRecord
|
688
|
+
from .helpers import to_mixture
|
689
|
+
from .helpers import to_target
|
529
690
|
|
530
691
|
with self.db() as c:
|
531
692
|
mixtures: Mixtures = []
|
532
693
|
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
533
|
-
targets = [
|
534
|
-
|
535
|
-
|
536
|
-
|
694
|
+
targets = [
|
695
|
+
to_target(TargetRecord(*target))
|
696
|
+
for target in c.execute(
|
697
|
+
"SELECT target.* FROM target, mixture_target "
|
698
|
+
+ "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
699
|
+
(mixture.id,),
|
700
|
+
).fetchall()
|
701
|
+
]
|
537
702
|
mixtures.append(to_mixture(mixture, targets))
|
538
703
|
return mixtures
|
539
704
|
|
@@ -559,23 +724,15 @@ class MixtureDatabase:
|
|
559
724
|
with self.db() as c:
|
560
725
|
return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
|
561
726
|
|
562
|
-
def
|
563
|
-
"""
|
727
|
+
def mixture_location(self, m_id: int) -> str:
|
728
|
+
"""Get the file location for the give mixture ID
|
564
729
|
|
565
|
-
:param
|
566
|
-
:return:
|
730
|
+
:param m_id: Zero-based mixture ID
|
731
|
+
:return: File location
|
567
732
|
"""
|
568
733
|
from os.path import join
|
569
734
|
|
570
|
-
return join(self.location, name)
|
571
|
-
|
572
|
-
def mixture_filename(self, m_id: int) -> str:
|
573
|
-
"""Get the HDF5 file name for the give mixture ID
|
574
|
-
|
575
|
-
:param m_id: Zero-based mixture ID
|
576
|
-
:return: File name
|
577
|
-
"""
|
578
|
-
return self.location_filename(self.mixture(m_id).name)
|
735
|
+
return join(self.location, self.mixture(m_id).name)
|
579
736
|
|
580
737
|
@cached_property
|
581
738
|
def num_mixtures(self) -> int:
|
@@ -593,9 +750,9 @@ class MixtureDatabase:
|
|
593
750
|
:param items: String(s) of dataset(s) to retrieve
|
594
751
|
:return: Data (or tuple of data)
|
595
752
|
"""
|
596
|
-
from .
|
753
|
+
from sonusai.mixture import read_cached_data
|
597
754
|
|
598
|
-
return
|
755
|
+
return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
|
599
756
|
|
600
757
|
def read_target_audio(self, t_id: int) -> AudioT:
|
601
758
|
"""Read target audio
|
@@ -622,10 +779,19 @@ class MixtureDatabase:
|
|
622
779
|
audio = read_audio(noise.name)
|
623
780
|
audio = apply_augmentation(audio, mixture.noise.augmentation)
|
624
781
|
if mixture.noise.augmentation.ir is not None:
|
625
|
-
audio = apply_impulse_response(
|
782
|
+
audio = apply_impulse_response(
|
783
|
+
audio,
|
784
|
+
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
|
785
|
+
)
|
626
786
|
|
627
787
|
return audio
|
628
788
|
|
789
|
+
def mixture_class_indices(self, m_id: int) -> list[int]:
|
790
|
+
class_indices: list[int] = []
|
791
|
+
for t_id in self.mixture(m_id).target_ids:
|
792
|
+
class_indices.extend(self.target_file(t_id).class_indices)
|
793
|
+
return sorted(set(class_indices))
|
794
|
+
|
629
795
|
def mixture_targets(self, m_id: int, force: bool = False) -> AudiosT:
|
630
796
|
"""Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
|
631
797
|
|
@@ -633,36 +799,34 @@ class MixtureDatabase:
|
|
633
799
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
634
800
|
:return: List of augmented target audio data (one per target in the mixup)
|
635
801
|
"""
|
636
|
-
from sonusai import SonusAIError
|
637
802
|
from .augmentation import apply_augmentation
|
638
803
|
from .augmentation import apply_gain
|
639
804
|
from .augmentation import pad_audio_to_length
|
640
805
|
|
641
806
|
if not force:
|
642
|
-
targets_audio = self.read_mixture_data(m_id,
|
807
|
+
targets_audio = self.read_mixture_data(m_id, "targets")
|
643
808
|
if targets_audio is not None:
|
644
809
|
return list(targets_audio)
|
645
810
|
|
646
811
|
mixture = self.mixture(m_id)
|
647
812
|
if mixture is None:
|
648
|
-
raise
|
813
|
+
raise ValueError(f"Could not find mixture for m_id: {m_id}")
|
649
814
|
|
650
815
|
targets_audio = []
|
651
816
|
for target in mixture.targets:
|
652
817
|
target_audio = self.read_target_audio(target.file_id)
|
653
|
-
target_audio = apply_augmentation(
|
654
|
-
|
655
|
-
|
818
|
+
target_audio = apply_augmentation(
|
819
|
+
audio=target_audio,
|
820
|
+
augmentation=target.augmentation,
|
821
|
+
frame_length=self.feature_step_samples,
|
822
|
+
)
|
656
823
|
target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
|
657
824
|
target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
|
658
825
|
targets_audio.append(target_audio)
|
659
826
|
|
660
827
|
return targets_audio
|
661
828
|
|
662
|
-
def mixture_targets_f(self,
|
663
|
-
m_id: int,
|
664
|
-
targets: Optional[AudiosT] = None,
|
665
|
-
force: bool = False) -> AudiosF:
|
829
|
+
def mixture_targets_f(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudiosF:
|
666
830
|
"""Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
|
667
831
|
|
668
832
|
:param m_id: Zero-based mixture ID
|
@@ -677,10 +841,7 @@ class MixtureDatabase:
|
|
677
841
|
|
678
842
|
return [forward_transform(target, self.ft_config) for target in targets]
|
679
843
|
|
680
|
-
def mixture_target(self,
|
681
|
-
m_id: int,
|
682
|
-
targets: Optional[AudiosT] = None,
|
683
|
-
force: bool = False) -> AudioT:
|
844
|
+
def mixture_target(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudioT:
|
684
845
|
"""Get the augmented target audio data for the given mixture ID
|
685
846
|
|
686
847
|
:param m_id: Zero-based mixture ID
|
@@ -691,7 +852,7 @@ class MixtureDatabase:
|
|
691
852
|
from .helpers import get_target
|
692
853
|
|
693
854
|
if not force:
|
694
|
-
target = self.read_mixture_data(m_id,
|
855
|
+
target = self.read_mixture_data(m_id, "target")
|
695
856
|
if target is not None:
|
696
857
|
return target
|
697
858
|
|
@@ -700,11 +861,13 @@ class MixtureDatabase:
|
|
700
861
|
|
701
862
|
return get_target(self, self.mixture(m_id), targets)
|
702
863
|
|
703
|
-
def mixture_target_f(
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
864
|
+
def mixture_target_f(
|
865
|
+
self,
|
866
|
+
m_id: int,
|
867
|
+
targets: AudiosT | None = None,
|
868
|
+
target: AudioT | None = None,
|
869
|
+
force: bool = False,
|
870
|
+
) -> AudioF:
|
708
871
|
"""Get the augmented target transform data for the given mixture ID
|
709
872
|
|
710
873
|
:param m_id: Zero-based mixture ID
|
@@ -720,9 +883,7 @@ class MixtureDatabase:
|
|
720
883
|
|
721
884
|
return forward_transform(target, self.ft_config)
|
722
885
|
|
723
|
-
def mixture_noise(self,
|
724
|
-
m_id: int,
|
725
|
-
force: bool = False) -> AudioT:
|
886
|
+
def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
|
726
887
|
"""Get the augmented noise audio data for the given mixture ID
|
727
888
|
|
728
889
|
:param m_id: Zero-based mixture ID
|
@@ -733,7 +894,7 @@ class MixtureDatabase:
|
|
733
894
|
from .augmentation import apply_gain
|
734
895
|
|
735
896
|
if not force:
|
736
|
-
noise = self.read_mixture_data(m_id,
|
897
|
+
noise = self.read_mixture_data(m_id, "noise")
|
737
898
|
if noise is not None:
|
738
899
|
return noise
|
739
900
|
|
@@ -742,10 +903,7 @@ class MixtureDatabase:
|
|
742
903
|
noise = get_next_noise(audio=noise, offset=mixture.noise.offset, length=mixture.samples)
|
743
904
|
return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
|
744
905
|
|
745
|
-
def mixture_noise_f(self,
|
746
|
-
m_id: int,
|
747
|
-
noise: Optional[AudioT] = None,
|
748
|
-
force: bool = False) -> AudioF:
|
906
|
+
def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
|
749
907
|
"""Get the augmented noise transform for the given mixture ID
|
750
908
|
|
751
909
|
:param m_id: Zero-based mixture ID
|
@@ -760,12 +918,14 @@ class MixtureDatabase:
|
|
760
918
|
|
761
919
|
return forward_transform(noise, self.ft_config)
|
762
920
|
|
763
|
-
def mixture_mixture(
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
921
|
+
def mixture_mixture(
|
922
|
+
self,
|
923
|
+
m_id: int,
|
924
|
+
targets: AudiosT | None = None,
|
925
|
+
target: AudioT | None = None,
|
926
|
+
noise: AudioT | None = None,
|
927
|
+
force: bool = False,
|
928
|
+
) -> AudioT:
|
769
929
|
"""Get the mixture audio data for the given mixture ID
|
770
930
|
|
771
931
|
:param m_id: Zero-based mixture ID
|
@@ -776,7 +936,7 @@ class MixtureDatabase:
|
|
776
936
|
:return: Mixture audio data
|
777
937
|
"""
|
778
938
|
if not force:
|
779
|
-
mixture = self.read_mixture_data(m_id,
|
939
|
+
mixture = self.read_mixture_data(m_id, "mixture")
|
780
940
|
if mixture is not None:
|
781
941
|
return mixture
|
782
942
|
|
@@ -788,13 +948,15 @@ class MixtureDatabase:
|
|
788
948
|
|
789
949
|
return target + noise
|
790
950
|
|
791
|
-
def mixture_mixture_f(
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
951
|
+
def mixture_mixture_f(
|
952
|
+
self,
|
953
|
+
m_id: int,
|
954
|
+
targets: AudiosT | None = None,
|
955
|
+
target: AudioT | None = None,
|
956
|
+
noise: AudioT | None = None,
|
957
|
+
mixture: AudioT | None = None,
|
958
|
+
force: bool = False,
|
959
|
+
) -> AudioF:
|
798
960
|
"""Get the mixture transform for the given mixture ID
|
799
961
|
|
800
962
|
:param m_id: Zero-based mixture ID
|
@@ -815,18 +977,22 @@ class MixtureDatabase:
|
|
815
977
|
|
816
978
|
m = self.mixture(m_id)
|
817
979
|
if m.spectral_mask_id is not None:
|
818
|
-
mixture_f = apply_spectral_mask(
|
819
|
-
|
820
|
-
|
980
|
+
mixture_f = apply_spectral_mask(
|
981
|
+
audio_f=mixture_f,
|
982
|
+
spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
|
983
|
+
seed=m.spectral_mask_seed,
|
984
|
+
)
|
821
985
|
|
822
986
|
return mixture_f
|
823
987
|
|
824
|
-
def mixture_truth_t(
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
988
|
+
def mixture_truth_t(
|
989
|
+
self,
|
990
|
+
m_id: int,
|
991
|
+
targets: AudiosT | None = None,
|
992
|
+
noise: AudioT | None = None,
|
993
|
+
mixture: AudioT | None = None,
|
994
|
+
force: bool = False,
|
995
|
+
) -> TruthDict:
|
830
996
|
"""Get the truth_t data for the given mixture ID
|
831
997
|
|
832
998
|
:param m_id: Zero-based mixture ID
|
@@ -836,10 +1002,10 @@ class MixtureDatabase:
|
|
836
1002
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
837
1003
|
:return: truth_t data
|
838
1004
|
"""
|
839
|
-
from .helpers import
|
1005
|
+
from .helpers import get_truth
|
840
1006
|
|
841
1007
|
if not force:
|
842
|
-
truth_t = self.read_mixture_data(m_id,
|
1008
|
+
truth_t = self.read_mixture_data(m_id, "truth_t")
|
843
1009
|
if truth_t is not None:
|
844
1010
|
return truth_t
|
845
1011
|
|
@@ -850,19 +1016,18 @@ class MixtureDatabase:
|
|
850
1016
|
noise = self.mixture_noise(m_id, force)
|
851
1017
|
|
852
1018
|
if force or mixture is None:
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
force: bool = False) -> Segsnr:
|
1019
|
+
mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
|
1020
|
+
|
1021
|
+
return get_truth(self, self.mixture(m_id), targets, noise, mixture)
|
1022
|
+
|
1023
|
+
def mixture_segsnr_t(
|
1024
|
+
self,
|
1025
|
+
m_id: int,
|
1026
|
+
targets: AudiosT | None = None,
|
1027
|
+
target: AudioT | None = None,
|
1028
|
+
noise: AudioT | None = None,
|
1029
|
+
force: bool = False,
|
1030
|
+
) -> Segsnr:
|
866
1031
|
"""Get the segsnr_t data for the given mixture ID
|
867
1032
|
|
868
1033
|
:param m_id: Zero-based mixture ID
|
@@ -875,7 +1040,7 @@ class MixtureDatabase:
|
|
875
1040
|
from .helpers import get_segsnr_t
|
876
1041
|
|
877
1042
|
if not force:
|
878
|
-
segsnr_t = self.read_mixture_data(m_id,
|
1043
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
879
1044
|
if segsnr_t is not None:
|
880
1045
|
return segsnr_t
|
881
1046
|
|
@@ -887,13 +1052,15 @@ class MixtureDatabase:
|
|
887
1052
|
|
888
1053
|
return get_segsnr_t(self, self.mixture(m_id), target, noise)
|
889
1054
|
|
890
|
-
def mixture_segsnr(
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
1055
|
+
def mixture_segsnr(
|
1056
|
+
self,
|
1057
|
+
m_id: int,
|
1058
|
+
segsnr_t: Segsnr | None = None,
|
1059
|
+
targets: AudiosT | None = None,
|
1060
|
+
target: AudioT | None = None,
|
1061
|
+
noise: AudioT | None = None,
|
1062
|
+
force: bool = False,
|
1063
|
+
) -> Segsnr:
|
897
1064
|
"""Get the segsnr data for the given mixture ID
|
898
1065
|
|
899
1066
|
:param m_id: Zero-based mixture ID
|
@@ -905,28 +1072,30 @@ class MixtureDatabase:
|
|
905
1072
|
:return: segsnr data
|
906
1073
|
"""
|
907
1074
|
if not force:
|
908
|
-
segsnr = self.read_mixture_data(m_id,
|
1075
|
+
segsnr = self.read_mixture_data(m_id, "segsnr")
|
909
1076
|
if segsnr is not None:
|
910
1077
|
return segsnr
|
911
1078
|
|
912
|
-
segsnr_t = self.read_mixture_data(m_id,
|
1079
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
913
1080
|
if segsnr_t is not None:
|
914
|
-
return segsnr_t[0::self.ft_config.
|
1081
|
+
return segsnr_t[0 :: self.ft_config.overlap]
|
915
1082
|
|
916
1083
|
if force or segsnr_t is None:
|
917
1084
|
segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
|
918
1085
|
|
919
|
-
return segsnr_t[0::self.ft_config.
|
920
|
-
|
921
|
-
def mixture_ft(
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
1086
|
+
return segsnr_t[0 :: self.ft_config.overlap]
|
1087
|
+
|
1088
|
+
def mixture_ft(
|
1089
|
+
self,
|
1090
|
+
m_id: int,
|
1091
|
+
targets: AudiosT | None = None,
|
1092
|
+
target: AudioT | None = None,
|
1093
|
+
noise: AudioT | None = None,
|
1094
|
+
mixture_f: AudioF | None = None,
|
1095
|
+
mixture: AudioT | None = None,
|
1096
|
+
truth_t: TruthDict | None = None,
|
1097
|
+
force: bool = False,
|
1098
|
+
) -> tuple[Feature, TruthDict]:
|
930
1099
|
"""Get the feature and truth_f data for the given mixture ID
|
931
1100
|
|
932
1101
|
:param m_id: Zero-based mixture ID
|
@@ -939,63 +1108,45 @@ class MixtureDatabase:
|
|
939
1108
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
940
1109
|
:return: Tuple of (feature, truth_f) data
|
941
1110
|
"""
|
942
|
-
from dataclasses import asdict
|
943
|
-
|
944
|
-
import numpy as np
|
945
1111
|
from pyaaware import FeatureGenerator
|
946
1112
|
|
947
|
-
from .truth import
|
1113
|
+
from .truth import truth_stride_reduction
|
948
1114
|
|
949
1115
|
if not force:
|
950
|
-
feature, truth_f = self.read_mixture_data(m_id, [
|
1116
|
+
feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
|
951
1117
|
if feature is not None and truth_f is not None:
|
952
1118
|
return feature, truth_f
|
953
1119
|
|
954
1120
|
if force or mixture_f is None:
|
955
|
-
mixture_f = self.mixture_mixture_f(
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
1121
|
+
mixture_f = self.mixture_mixture_f(
|
1122
|
+
m_id=m_id,
|
1123
|
+
targets=targets,
|
1124
|
+
target=target,
|
1125
|
+
noise=noise,
|
1126
|
+
mixture=mixture,
|
1127
|
+
force=force,
|
1128
|
+
)
|
961
1129
|
|
962
1130
|
if force or truth_t is None:
|
963
1131
|
truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
|
964
1132
|
|
965
|
-
|
966
|
-
transform_frames = self.mixture_transform_frames(m_id)
|
967
|
-
feature_frames = self.mixture_feature_frames(m_id)
|
968
|
-
|
969
|
-
if truth_t is None:
|
970
|
-
truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
|
1133
|
+
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
971
1134
|
|
972
|
-
feature =
|
973
|
-
|
974
|
-
|
975
|
-
fg = FeatureGenerator(**asdict(self.fg_config))
|
976
|
-
feature_frame = 0
|
977
|
-
for transform_frame in range(transform_frames):
|
978
|
-
indices = slice(transform_frame * self.ft_config.R, (transform_frame + 1) * self.ft_config.R)
|
979
|
-
fg.execute(mixture_f[transform_frame],
|
980
|
-
truth_reduction(truth_t[indices], self.truth_reduction_function))
|
981
|
-
|
982
|
-
if fg.eof():
|
983
|
-
feature[feature_frame] = fg.feature()
|
984
|
-
truth_f[feature_frame] = fg.truth()
|
985
|
-
feature_frame += 1
|
986
|
-
|
987
|
-
if np.isreal(truth_f).all():
|
988
|
-
return feature, truth_f.real
|
1135
|
+
feature, truth_f = fg.execute_all(mixture_f, truth_t)
|
1136
|
+
for key in self.truth_configs:
|
1137
|
+
truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
|
989
1138
|
|
990
1139
|
return feature, truth_f
|
991
1140
|
|
992
|
-
def mixture_feature(
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
1141
|
+
def mixture_feature(
|
1142
|
+
self,
|
1143
|
+
m_id: int,
|
1144
|
+
targets: AudiosT | None = None,
|
1145
|
+
noise: AudioT | None = None,
|
1146
|
+
mixture: AudioT | None = None,
|
1147
|
+
truth_t: TruthDict | None = None,
|
1148
|
+
force: bool = False,
|
1149
|
+
) -> Feature:
|
999
1150
|
"""Get the feature data for the given mixture ID
|
1000
1151
|
|
1001
1152
|
:param m_id: Zero-based mixture ID
|
@@ -1006,21 +1157,25 @@ class MixtureDatabase:
|
|
1006
1157
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1007
1158
|
:return: Feature data
|
1008
1159
|
"""
|
1009
|
-
feature, _ = self.mixture_ft(
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1160
|
+
feature, _ = self.mixture_ft(
|
1161
|
+
m_id=m_id,
|
1162
|
+
targets=targets,
|
1163
|
+
noise=noise,
|
1164
|
+
mixture=mixture,
|
1165
|
+
truth_t=truth_t,
|
1166
|
+
force=force,
|
1167
|
+
)
|
1015
1168
|
return feature
|
1016
1169
|
|
1017
|
-
def mixture_truth_f(
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1170
|
+
def mixture_truth_f(
|
1171
|
+
self,
|
1172
|
+
m_id: int,
|
1173
|
+
targets: AudiosT | None = None,
|
1174
|
+
noise: AudioT | None = None,
|
1175
|
+
mixture: AudioT | None = None,
|
1176
|
+
truth_t: TruthDict | None = None,
|
1177
|
+
force: bool = False,
|
1178
|
+
) -> TruthDict:
|
1024
1179
|
"""Get the truth_f data for the given mixture ID
|
1025
1180
|
|
1026
1181
|
:param m_id: Zero-based mixture ID
|
@@ -1031,20 +1186,24 @@ class MixtureDatabase:
|
|
1031
1186
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1032
1187
|
:return: truth_f data
|
1033
1188
|
"""
|
1034
|
-
_, truth_f = self.mixture_ft(
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1189
|
+
_, truth_f = self.mixture_ft(
|
1190
|
+
m_id=m_id,
|
1191
|
+
targets=targets,
|
1192
|
+
noise=noise,
|
1193
|
+
mixture=mixture,
|
1194
|
+
truth_t=truth_t,
|
1195
|
+
force=force,
|
1196
|
+
)
|
1040
1197
|
return truth_f
|
1041
1198
|
|
1042
|
-
def mixture_class_count(
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1199
|
+
def mixture_class_count(
|
1200
|
+
self,
|
1201
|
+
m_id: int,
|
1202
|
+
targets: AudiosT | None = None,
|
1203
|
+
noise: AudioT | None = None,
|
1204
|
+
truth_t: TruthDict | None = None,
|
1205
|
+
) -> ClassCount:
|
1206
|
+
"""Compute the number of frames for which each class index is active for the given mixture ID
|
1048
1207
|
|
1049
1208
|
:param m_id: Zero-based mixture ID
|
1050
1209
|
:param targets: List of augmented target audio (one per target in the mixup)
|
@@ -1059,10 +1218,9 @@ class MixtureDatabase:
|
|
1059
1218
|
|
1060
1219
|
class_count = [0] * self.num_classes
|
1061
1220
|
num_classes = self.num_classes
|
1062
|
-
if self.
|
1063
|
-
|
1064
|
-
|
1065
|
-
class_count[cl] = int(np.sum(truth_t[:, cl] >= self.class_weights_thresholds[cl]))
|
1221
|
+
if "sed" in self.truth_configs:
|
1222
|
+
for cl in range(num_classes):
|
1223
|
+
class_count[cl] = int(np.sum(truth_t["sed"][:, cl] >= self.class_weights_thresholds[cl]))
|
1066
1224
|
|
1067
1225
|
return class_count
|
1068
1226
|
|
@@ -1084,7 +1242,7 @@ class MixtureDatabase:
|
|
1084
1242
|
def speech_metadata_tiers(self) -> list[str]:
|
1085
1243
|
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1086
1244
|
|
1087
|
-
def speaker(self, s_id: int | None, tier: str) ->
|
1245
|
+
def speaker(self, s_id: int | None, tier: str) -> str | None:
|
1088
1246
|
return _speaker(self.db, s_id, tier)
|
1089
1247
|
|
1090
1248
|
def speech_metadata(self, tier: str) -> list[str]:
|
@@ -1124,9 +1282,13 @@ class MixtureDatabase:
|
|
1124
1282
|
entries = []
|
1125
1283
|
for entry in data:
|
1126
1284
|
if target.augmentation.tempo is not None:
|
1127
|
-
entries.append(
|
1128
|
-
|
1129
|
-
|
1285
|
+
entries.append(
|
1286
|
+
Interval(
|
1287
|
+
entry.start / target.augmentation.tempo,
|
1288
|
+
entry.end / target.augmentation.tempo,
|
1289
|
+
entry.label,
|
1290
|
+
)
|
1291
|
+
)
|
1130
1292
|
else:
|
1131
1293
|
entries.append(entry)
|
1132
1294
|
results.append(entries)
|
@@ -1136,12 +1298,9 @@ class MixtureDatabase:
|
|
1136
1298
|
for target in self.mixture(mixid).targets:
|
1137
1299
|
results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
|
1138
1300
|
|
1139
|
-
return
|
1301
|
+
return results
|
1140
1302
|
|
1141
|
-
def mixids_for_speech_metadata(self,
|
1142
|
-
tier: str,
|
1143
|
-
value: str = None,
|
1144
|
-
where: str = None) -> list[int]:
|
1303
|
+
def mixids_for_speech_metadata(self, tier: str, value: str | None = None, where: str | None = None) -> list[int]:
|
1145
1304
|
"""Get a list of mixture IDs for the given speech metadata tier.
|
1146
1305
|
|
1147
1306
|
If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
@@ -1160,25 +1319,27 @@ class MixtureDatabase:
|
|
1160
1319
|
>>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
|
1161
1320
|
Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
|
1162
1321
|
"""
|
1163
|
-
from sonusai import SonusAIError
|
1164
|
-
|
1165
1322
|
if value is None and where is None:
|
1166
|
-
raise
|
1323
|
+
raise ValueError("Must provide either value or where")
|
1167
1324
|
|
1168
1325
|
if where is None:
|
1169
1326
|
where = f"{tier} = '{value}'"
|
1170
1327
|
|
1171
1328
|
if tier in self.textgrid_metadata_tiers:
|
1172
|
-
raise
|
1329
|
+
raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
|
1173
1330
|
|
1174
1331
|
with self.db() as c:
|
1175
|
-
speaker_ids = [
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1332
|
+
speaker_ids = [
|
1333
|
+
speaker_id[0] for speaker_id in c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
|
1334
|
+
]
|
1335
|
+
results = c.execute(
|
1336
|
+
"SELECT id FROM target_file " + f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})"
|
1337
|
+
).fetchall()
|
1179
1338
|
target_file_ids = [target_file_id[0] for target_file_id in results]
|
1180
|
-
results = c.execute(
|
1181
|
-
|
1339
|
+
results = c.execute(
|
1340
|
+
"SELECT mixture_id FROM mixture_target "
|
1341
|
+
+ f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
|
1342
|
+
).fetchall()
|
1182
1343
|
|
1183
1344
|
return [mixture_id[0] - 1 for mixture_id in results]
|
1184
1345
|
|
@@ -1187,9 +1348,9 @@ class MixtureDatabase:
|
|
1187
1348
|
|
1188
1349
|
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
1189
1350
|
|
1190
|
-
def mixture_metrics(
|
1191
|
-
|
1192
|
-
|
1351
|
+
def mixture_metrics(
|
1352
|
+
self, m_id: int, metrics: list[str], force: bool = False
|
1353
|
+
) -> list[float | int | str | Segsnr | None]:
|
1193
1354
|
"""Get metrics data for the given mixture ID
|
1194
1355
|
|
1195
1356
|
:param m_id: Zero-based mixture ID
|
@@ -1197,12 +1358,11 @@ class MixtureDatabase:
|
|
1197
1358
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1198
1359
|
:return: List of metric data
|
1199
1360
|
"""
|
1200
|
-
from
|
1361
|
+
from collections.abc import Callable
|
1201
1362
|
|
1202
1363
|
import numpy as np
|
1203
1364
|
from pystoi import stoi
|
1204
1365
|
|
1205
|
-
from sonusai import SonusAIError
|
1206
1366
|
from sonusai.metrics import calc_audio_stats
|
1207
1367
|
from sonusai.metrics import calc_phase_distance
|
1208
1368
|
from sonusai.metrics import calc_segsnr_f
|
@@ -1312,7 +1472,7 @@ class MixtureDatabase:
|
|
1312
1472
|
def get() -> AudioStatsMetrics:
|
1313
1473
|
nonlocal state
|
1314
1474
|
if state is None:
|
1315
|
-
state = calc_audio_stats(target_audio(), self.fg_info.ft_config.
|
1475
|
+
state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1316
1476
|
return state
|
1317
1477
|
|
1318
1478
|
return get
|
@@ -1325,7 +1485,7 @@ class MixtureDatabase:
|
|
1325
1485
|
def get() -> AudioStatsMetrics:
|
1326
1486
|
nonlocal state
|
1327
1487
|
if state is None:
|
1328
|
-
state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.
|
1488
|
+
state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1329
1489
|
return state
|
1330
1490
|
|
1331
1491
|
return get
|
@@ -1338,9 +1498,10 @@ class MixtureDatabase:
|
|
1338
1498
|
def get(asr_name) -> dict:
|
1339
1499
|
nonlocal state
|
1340
1500
|
if asr_name not in state:
|
1341
|
-
|
1342
|
-
if
|
1343
|
-
raise
|
1501
|
+
value = self.asr_configs.get(asr_name, None)
|
1502
|
+
if value is None:
|
1503
|
+
raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
|
1504
|
+
state[asr_name] = value
|
1344
1505
|
return state[asr_name]
|
1345
1506
|
|
1346
1507
|
return get
|
@@ -1374,15 +1535,14 @@ class MixtureDatabase:
|
|
1374
1535
|
mixture_asr = create_mixture_asr()
|
1375
1536
|
|
1376
1537
|
def get_asr_name(m: str) -> str:
|
1377
|
-
parts = m.split(
|
1538
|
+
parts = m.split(".")
|
1378
1539
|
if len(parts) != 2:
|
1379
|
-
raise
|
1380
|
-
f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1540
|
+
raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1381
1541
|
asr_name = parts[1]
|
1382
1542
|
return asr_name
|
1383
1543
|
|
1384
|
-
def calc(m: str) -> float | int | str | Segsnr:
|
1385
|
-
if m ==
|
1544
|
+
def calc(m: str) -> float | int | str | Segsnr | None:
|
1545
|
+
if m == "mxsnr":
|
1386
1546
|
return self.mixture(m_id).snr
|
1387
1547
|
|
1388
1548
|
# Get cached data first, if exists
|
@@ -1392,12 +1552,12 @@ class MixtureDatabase:
|
|
1392
1552
|
return value
|
1393
1553
|
|
1394
1554
|
# Otherwise, generate data as needed
|
1395
|
-
if m.startswith(
|
1555
|
+
if m.startswith("mxwer"):
|
1396
1556
|
asr_name = get_asr_name(m)
|
1397
1557
|
|
1398
1558
|
if self.mixture(m_id).snr < -96:
|
1399
1559
|
# noise only, ignore/reset target asr
|
1400
|
-
return float(
|
1560
|
+
return float("nan")
|
1401
1561
|
|
1402
1562
|
if target_asr(asr_name):
|
1403
1563
|
return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
|
@@ -1405,149 +1565,166 @@ class MixtureDatabase:
|
|
1405
1565
|
# TODO: should this be NaN like above?
|
1406
1566
|
return float(0)
|
1407
1567
|
|
1408
|
-
if m.startswith(
|
1568
|
+
if m.startswith("basewer"):
|
1569
|
+
asr_name = get_asr_name(m)
|
1570
|
+
|
1571
|
+
text = self.mixture_speech_metadata(m_id, "text")[0]
|
1572
|
+
if text is not None:
|
1573
|
+
return calc_wer(target_asr(asr_name), text).wer * 100
|
1574
|
+
|
1575
|
+
# TODO: should this be NaN like above?
|
1576
|
+
return float(0)
|
1577
|
+
|
1578
|
+
if m.startswith("mxasr"):
|
1409
1579
|
return mixture_asr(get_asr_name(m))
|
1410
1580
|
|
1411
|
-
if m ==
|
1581
|
+
if m == "mxssnr_avg":
|
1412
1582
|
return calc_segsnr_f(segsnr_f()).avg
|
1413
1583
|
|
1414
|
-
if m ==
|
1584
|
+
if m == "mxssnr_std":
|
1415
1585
|
return calc_segsnr_f(segsnr_f()).std
|
1416
1586
|
|
1417
|
-
if m ==
|
1587
|
+
if m == "mxssnrdb_avg":
|
1418
1588
|
return calc_segsnr_f(segsnr_f()).db_avg
|
1419
1589
|
|
1420
|
-
if m ==
|
1590
|
+
if m == "mxssnrdb_std":
|
1421
1591
|
return calc_segsnr_f(segsnr_f()).db_std
|
1422
1592
|
|
1423
|
-
if m ==
|
1593
|
+
if m == "mxssnrf_avg":
|
1424
1594
|
return calc_segsnr_f_bin(target_f(), noise_f()).avg
|
1425
1595
|
|
1426
|
-
if m ==
|
1596
|
+
if m == "mxssnrf_std":
|
1427
1597
|
return calc_segsnr_f_bin(target_f(), noise_f()).std
|
1428
1598
|
|
1429
|
-
if m ==
|
1599
|
+
if m == "mxssnrdbf_avg":
|
1430
1600
|
return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
|
1431
1601
|
|
1432
|
-
if m ==
|
1602
|
+
if m == "mxssnrdbf_std":
|
1433
1603
|
return calc_segsnr_f_bin(target_f(), noise_f()).db_std
|
1434
1604
|
|
1435
|
-
if m ==
|
1605
|
+
if m == "mxpesq":
|
1436
1606
|
if self.mixture(m_id).snr < -96:
|
1437
1607
|
return 0
|
1438
1608
|
return speech().pesq
|
1439
1609
|
|
1440
|
-
if m ==
|
1610
|
+
if m == "mxcsig":
|
1441
1611
|
if self.mixture(m_id).snr < -96:
|
1442
1612
|
return 0
|
1443
1613
|
return speech().csig
|
1444
1614
|
|
1445
|
-
if m ==
|
1615
|
+
if m == "mxcbak":
|
1446
1616
|
if self.mixture(m_id).snr < -96:
|
1447
1617
|
return 0
|
1448
1618
|
return speech().cbak
|
1449
1619
|
|
1450
|
-
if m ==
|
1620
|
+
if m == "mxcovl":
|
1451
1621
|
if self.mixture(m_id).snr < -96:
|
1452
1622
|
return 0
|
1453
1623
|
return speech().covl
|
1454
1624
|
|
1455
|
-
if m ==
|
1625
|
+
if m == "mxwsdr":
|
1456
1626
|
mixture = mixture_audio()[:, np.newaxis]
|
1457
1627
|
target = target_audio()[:, np.newaxis]
|
1458
1628
|
noise = noise_audio()[:, np.newaxis]
|
1459
|
-
return calc_wsdr(
|
1460
|
-
|
1461
|
-
|
1629
|
+
return calc_wsdr(
|
1630
|
+
hypothesis=np.concatenate((mixture, noise), axis=1),
|
1631
|
+
reference=np.concatenate((target, noise), axis=1),
|
1632
|
+
with_log=True,
|
1633
|
+
)[0]
|
1462
1634
|
|
1463
|
-
if m ==
|
1635
|
+
if m == "mxpd":
|
1464
1636
|
mixture_f = self.mixture_mixture_f(m_id)
|
1465
1637
|
return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
|
1466
1638
|
|
1467
|
-
if m ==
|
1468
|
-
return stoi(
|
1639
|
+
if m == "mxstoi":
|
1640
|
+
return stoi(
|
1641
|
+
x=target_audio(),
|
1642
|
+
y=mixture_audio(),
|
1643
|
+
fs_sig=SAMPLE_RATE,
|
1644
|
+
extended=False,
|
1645
|
+
)
|
1469
1646
|
|
1470
|
-
if m ==
|
1647
|
+
if m == "tdco":
|
1471
1648
|
return target_stats().dco
|
1472
1649
|
|
1473
|
-
if m ==
|
1650
|
+
if m == "tmin":
|
1474
1651
|
return target_stats().min
|
1475
1652
|
|
1476
|
-
if m ==
|
1653
|
+
if m == "tmax":
|
1477
1654
|
return target_stats().max
|
1478
1655
|
|
1479
|
-
if m ==
|
1656
|
+
if m == "tpkdb":
|
1480
1657
|
return target_stats().pkdb
|
1481
1658
|
|
1482
|
-
if m ==
|
1659
|
+
if m == "tlrms":
|
1483
1660
|
return target_stats().lrms
|
1484
1661
|
|
1485
|
-
if m ==
|
1662
|
+
if m == "tpkr":
|
1486
1663
|
return target_stats().pkr
|
1487
1664
|
|
1488
|
-
if m ==
|
1665
|
+
if m == "ttr":
|
1489
1666
|
return target_stats().tr
|
1490
1667
|
|
1491
|
-
if m ==
|
1668
|
+
if m == "tcr":
|
1492
1669
|
return target_stats().cr
|
1493
1670
|
|
1494
|
-
if m ==
|
1671
|
+
if m == "tfl":
|
1495
1672
|
return target_stats().fl
|
1496
1673
|
|
1497
|
-
if m ==
|
1674
|
+
if m == "tpkc":
|
1498
1675
|
return target_stats().pkc
|
1499
1676
|
|
1500
|
-
if m.startswith(
|
1677
|
+
if m.startswith("tasr"):
|
1501
1678
|
return target_asr(get_asr_name(m))
|
1502
1679
|
|
1503
|
-
if m ==
|
1680
|
+
if m == "ndco":
|
1504
1681
|
return noise_stats().dco
|
1505
1682
|
|
1506
|
-
if m ==
|
1683
|
+
if m == "nmin":
|
1507
1684
|
return noise_stats().min
|
1508
1685
|
|
1509
|
-
if m ==
|
1686
|
+
if m == "nmax":
|
1510
1687
|
return noise_stats().max
|
1511
1688
|
|
1512
|
-
if m ==
|
1689
|
+
if m == "npkdb":
|
1513
1690
|
return noise_stats().pkdb
|
1514
1691
|
|
1515
|
-
if m ==
|
1692
|
+
if m == "nlrms":
|
1516
1693
|
return noise_stats().lrms
|
1517
1694
|
|
1518
|
-
if m ==
|
1695
|
+
if m == "npkr":
|
1519
1696
|
return noise_stats().pkr
|
1520
1697
|
|
1521
|
-
if m ==
|
1698
|
+
if m == "ntr":
|
1522
1699
|
return noise_stats().tr
|
1523
1700
|
|
1524
|
-
if m ==
|
1701
|
+
if m == "ncr":
|
1525
1702
|
return noise_stats().cr
|
1526
1703
|
|
1527
|
-
if m ==
|
1704
|
+
if m == "nfl":
|
1528
1705
|
return noise_stats().fl
|
1529
1706
|
|
1530
|
-
if m ==
|
1707
|
+
if m == "npkc":
|
1531
1708
|
return noise_stats().pkc
|
1532
1709
|
|
1533
|
-
if m ==
|
1710
|
+
if m == "sedavg":
|
1534
1711
|
return 0
|
1535
1712
|
|
1536
|
-
if m ==
|
1713
|
+
if m == "sedcnt":
|
1537
1714
|
return 0
|
1538
1715
|
|
1539
|
-
if m ==
|
1716
|
+
if m == "sedtop3":
|
1540
1717
|
return np.zeros(3, dtype=np.float32)
|
1541
1718
|
|
1542
|
-
if m ==
|
1719
|
+
if m == "sedtopn":
|
1543
1720
|
return 0
|
1544
1721
|
|
1545
|
-
if m ==
|
1722
|
+
if m == "ssnr":
|
1546
1723
|
return segsnr_f()
|
1547
1724
|
|
1548
|
-
raise
|
1725
|
+
raise AttributeError(f"Unrecognized metric: '{m}'")
|
1549
1726
|
|
1550
|
-
result: list[float | int | str | Segsnr] = []
|
1727
|
+
result: list[float | int | str | Segsnr | None] = []
|
1551
1728
|
for metric in metrics:
|
1552
1729
|
result.append(calc(metric))
|
1553
1730
|
|
@@ -1565,13 +1742,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1565
1742
|
from .db_datatypes import SpectralMaskRecord
|
1566
1743
|
|
1567
1744
|
with db() as c:
|
1568
|
-
spectral_mask = SpectralMaskRecord(
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1745
|
+
spectral_mask = SpectralMaskRecord(
|
1746
|
+
*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id", (sm_id,)).fetchone()
|
1747
|
+
)
|
1748
|
+
return SpectralMask(
|
1749
|
+
f_max_width=spectral_mask.f_max_width,
|
1750
|
+
f_num=spectral_mask.f_num,
|
1751
|
+
t_max_width=spectral_mask.t_max_width,
|
1752
|
+
t_num=spectral_mask.t_num,
|
1753
|
+
t_max_percent=spectral_mask.t_max_percent,
|
1754
|
+
)
|
1575
1755
|
|
1576
1756
|
|
1577
1757
|
@lru_cache
|
@@ -1584,30 +1764,21 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1584
1764
|
"""
|
1585
1765
|
import json
|
1586
1766
|
|
1587
|
-
from .datatypes import TruthSetting
|
1588
|
-
from .datatypes import TruthSettings
|
1589
1767
|
from .db_datatypes import TargetFileRecord
|
1590
1768
|
|
1591
1769
|
with db() as c:
|
1592
1770
|
target_file = TargetFileRecord(
|
1593
|
-
*c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
function=entry.get('function', None),
|
1605
|
-
index=entry.get('index', None)))
|
1606
|
-
return TargetFile(name=target_file.name,
|
1607
|
-
samples=target_file.samples,
|
1608
|
-
level_type=target_file.level_type,
|
1609
|
-
truth_settings=truth_settings,
|
1610
|
-
speaker_id=target_file.speaker_id)
|
1771
|
+
*c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
|
1772
|
+
)
|
1773
|
+
|
1774
|
+
return TargetFile(
|
1775
|
+
name=target_file.name,
|
1776
|
+
samples=target_file.samples,
|
1777
|
+
class_indices=json.loads(target_file.class_indices),
|
1778
|
+
level_type=target_file.level_type,
|
1779
|
+
truth_configs=_target_truth_configs(db, t_id),
|
1780
|
+
speaker_id=target_file.speaker_id,
|
1781
|
+
)
|
1611
1782
|
|
1612
1783
|
|
1613
1784
|
@lru_cache
|
@@ -1619,8 +1790,10 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
1619
1790
|
:return: Noise file
|
1620
1791
|
"""
|
1621
1792
|
with db() as c:
|
1622
|
-
noise = c.execute(
|
1623
|
-
|
1793
|
+
noise = c.execute(
|
1794
|
+
"SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
|
1795
|
+
(n_id,),
|
1796
|
+
).fetchone()
|
1624
1797
|
return NoiseFile(name=noise[0], samples=noise[1])
|
1625
1798
|
|
1626
1799
|
|
@@ -1633,9 +1806,12 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
1633
1806
|
:return: Noise
|
1634
1807
|
"""
|
1635
1808
|
with db() as c:
|
1636
|
-
return str(
|
1637
|
-
|
1638
|
-
|
1809
|
+
return str(
|
1810
|
+
c.execute(
|
1811
|
+
"SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
|
1812
|
+
(ir_id + 1,),
|
1813
|
+
).fetchone()[0]
|
1814
|
+
)
|
1639
1815
|
|
1640
1816
|
|
1641
1817
|
@lru_cache
|
@@ -1646,31 +1822,59 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
1646
1822
|
:param m_id: Zero-based mixture ID
|
1647
1823
|
:return: Mixture record
|
1648
1824
|
"""
|
1649
|
-
from .helpers import to_mixture
|
1650
|
-
from .helpers import to_target
|
1651
1825
|
from .db_datatypes import MixtureRecord
|
1652
1826
|
from .db_datatypes import TargetRecord
|
1827
|
+
from .helpers import to_mixture
|
1828
|
+
from .helpers import to_target
|
1653
1829
|
|
1654
1830
|
with db() as c:
|
1655
1831
|
mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
|
1656
|
-
targets = [
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
1660
|
-
|
1832
|
+
targets = [
|
1833
|
+
to_target(TargetRecord(*target))
|
1834
|
+
for target in c.execute(
|
1835
|
+
"SELECT target.* "
|
1836
|
+
+ "FROM target, mixture_target "
|
1837
|
+
+ "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
1838
|
+
(mixture.id,),
|
1839
|
+
).fetchall()
|
1840
|
+
]
|
1661
1841
|
|
1662
1842
|
return to_mixture(mixture, targets)
|
1663
1843
|
|
1664
1844
|
|
1665
1845
|
@lru_cache
|
1666
|
-
def _speaker(db: partial, s_id: int | None, tier: str) ->
|
1846
|
+
def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
1667
1847
|
if s_id is None:
|
1668
1848
|
return None
|
1669
1849
|
|
1670
1850
|
with db() as c:
|
1671
|
-
data = c.execute(f
|
1851
|
+
data = c.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
|
1672
1852
|
if data is None:
|
1673
1853
|
return None
|
1674
1854
|
if data[0] is None:
|
1675
1855
|
return None
|
1676
1856
|
return data[0]
|
1857
|
+
|
1858
|
+
|
1859
|
+
@lru_cache
|
1860
|
+
def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
1861
|
+
import json
|
1862
|
+
|
1863
|
+
from .datatypes import TruthConfig
|
1864
|
+
|
1865
|
+
truth_configs: TruthConfigs = {}
|
1866
|
+
with db() as c:
|
1867
|
+
for truth_config_record in c.execute(
|
1868
|
+
"SELECT truth_config.config "
|
1869
|
+
+ "FROM truth_config, target_file_truth_config "
|
1870
|
+
+ "WHERE ? = target_file_truth_config.target_file_id "
|
1871
|
+
+ "AND truth_config.id = target_file_truth_config.truth_config_id",
|
1872
|
+
(t_id,),
|
1873
|
+
).fetchall():
|
1874
|
+
truth_config = json.loads(truth_config_record[0])
|
1875
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
1876
|
+
function=truth_config["function"],
|
1877
|
+
stride_reduction=truth_config["stride_reduction"],
|
1878
|
+
config=truth_config["config"],
|
1879
|
+
)
|
1880
|
+
return truth_configs
|