sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -1,46 +1,47 @@
|
|
1
|
+
# ruff: noqa: S608
|
1
2
|
from functools import cached_property
|
2
3
|
from functools import lru_cache
|
3
4
|
from functools import partial
|
4
5
|
from sqlite3 import Connection
|
5
6
|
from sqlite3 import Cursor
|
6
7
|
from typing import Any
|
7
|
-
|
8
|
-
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
25
|
-
from
|
26
|
-
from
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from
|
31
|
-
from
|
32
|
-
from
|
33
|
-
from
|
34
|
-
from
|
8
|
+
|
9
|
+
from .datatypes import ASRConfigs
|
10
|
+
from .datatypes import AudioF
|
11
|
+
from .datatypes import AudiosF
|
12
|
+
from .datatypes import AudiosT
|
13
|
+
from .datatypes import AudioT
|
14
|
+
from .datatypes import ClassCount
|
15
|
+
from .datatypes import Feature
|
16
|
+
from .datatypes import FeatureGeneratorConfig
|
17
|
+
from .datatypes import FeatureGeneratorInfo
|
18
|
+
from .datatypes import GeneralizedIDs
|
19
|
+
from .datatypes import ImpulseResponseFiles
|
20
|
+
from .datatypes import MetricDoc
|
21
|
+
from .datatypes import MetricDocs
|
22
|
+
from .datatypes import Mixture
|
23
|
+
from .datatypes import Mixtures
|
24
|
+
from .datatypes import NoiseFile
|
25
|
+
from .datatypes import NoiseFiles
|
26
|
+
from .datatypes import Segsnr
|
27
|
+
from .datatypes import SpectralMask
|
28
|
+
from .datatypes import SpectralMasks
|
29
|
+
from .datatypes import SpeechMetadata
|
30
|
+
from .datatypes import TargetFile
|
31
|
+
from .datatypes import TargetFiles
|
32
|
+
from .datatypes import TransformConfig
|
33
|
+
from .datatypes import TruthConfigs
|
34
|
+
from .datatypes import TruthDict
|
35
|
+
from .datatypes import UniversalSNR
|
35
36
|
|
36
37
|
|
37
38
|
def db_file(location: str, test: bool = False) -> str:
|
38
39
|
from os.path import join
|
39
40
|
|
40
41
|
if test:
|
41
|
-
name =
|
42
|
+
name = "mixdb_test.db"
|
42
43
|
else:
|
43
|
-
name =
|
44
|
+
name = "mixdb.db"
|
44
45
|
|
45
46
|
return join(location, name)
|
46
47
|
|
@@ -50,19 +51,17 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
|
|
50
51
|
from os import remove
|
51
52
|
from os.path import exists
|
52
53
|
|
53
|
-
from sonusai import SonusAIError
|
54
|
-
|
55
54
|
name = db_file(location, test)
|
56
55
|
if create and exists(name):
|
57
56
|
remove(name)
|
58
57
|
|
59
58
|
if not create and not exists(name):
|
60
|
-
raise
|
59
|
+
raise OSError(f"Could not find mixture database in {location}")
|
61
60
|
|
62
61
|
if not create and readonly:
|
63
|
-
name +=
|
62
|
+
name += "?mode=ro"
|
64
63
|
|
65
|
-
connection = sqlite3.connect(
|
64
|
+
connection = sqlite3.connect("file:" + name, uri=True)
|
66
65
|
# connection.set_trace_callback(print)
|
67
66
|
return connection
|
68
67
|
|
@@ -103,25 +102,23 @@ class MixtureDatabase:
|
|
103
102
|
num_classes=self.num_classes,
|
104
103
|
spectral_masks=self.spectral_masks,
|
105
104
|
target_files=self.target_files,
|
106
|
-
truth_mutex=self.truth_mutex,
|
107
|
-
truth_reduction_function=self.truth_reduction_function
|
108
105
|
)
|
109
106
|
return config.to_json(indent=2)
|
110
107
|
|
111
108
|
def save(self) -> None:
|
112
|
-
"""Save the MixtureDatabase as a JSON file
|
113
|
-
"""
|
109
|
+
"""Save the MixtureDatabase as a JSON file"""
|
114
110
|
from os.path import join
|
115
111
|
|
116
|
-
json_name = join(self.location,
|
117
|
-
with open(file=json_name, mode=
|
112
|
+
json_name = join(self.location, "mixdb.json")
|
113
|
+
with open(file=json_name, mode="w") as file:
|
118
114
|
file.write(self.json)
|
119
115
|
|
120
116
|
@cached_property
|
121
117
|
def fg_config(self) -> FeatureGeneratorConfig:
|
122
|
-
return FeatureGeneratorConfig(
|
123
|
-
|
124
|
-
|
118
|
+
return FeatureGeneratorConfig(
|
119
|
+
feature_mode=self.feature,
|
120
|
+
truth_parameters=self.truth_parameters,
|
121
|
+
)
|
125
122
|
|
126
123
|
@cached_property
|
127
124
|
def fg_info(self) -> FeatureGeneratorInfo:
|
@@ -130,19 +127,18 @@ class MixtureDatabase:
|
|
130
127
|
return get_feature_generator_info(self.fg_config)
|
131
128
|
|
132
129
|
@cached_property
|
133
|
-
def
|
134
|
-
with self.db() as c:
|
135
|
-
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
136
|
-
|
137
|
-
@cached_property
|
138
|
-
def truth_mutex(self) -> bool:
|
130
|
+
def truth_parameters(self) -> dict[str, int]:
|
139
131
|
with self.db() as c:
|
140
|
-
|
132
|
+
rows = c.execute("SELECT * FROM truth_parameters").fetchall()
|
133
|
+
truth_parameters: dict[str, int] = {}
|
134
|
+
for row in rows:
|
135
|
+
truth_parameters[row[1]] = row[2]
|
136
|
+
return truth_parameters
|
141
137
|
|
142
138
|
@cached_property
|
143
|
-
def
|
139
|
+
def num_classes(self) -> int:
|
144
140
|
with self.db() as c:
|
145
|
-
return
|
141
|
+
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
146
142
|
|
147
143
|
@cached_property
|
148
144
|
def noise_mix_mode(self) -> str:
|
@@ -158,70 +154,152 @@ class MixtureDatabase:
|
|
158
154
|
|
159
155
|
@cached_property
|
160
156
|
def supported_metrics(self) -> MetricDocs:
|
161
|
-
metrics = MetricDocs(
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
157
|
+
metrics = MetricDocs(
|
158
|
+
[
|
159
|
+
MetricDoc("Mixture Metrics", "mxsnr", "SNR specification in dB"),
|
160
|
+
MetricDoc(
|
161
|
+
"Mixture Metrics",
|
162
|
+
"mxssnr_avg",
|
163
|
+
"Segmental SNR average over all frames",
|
164
|
+
),
|
165
|
+
MetricDoc(
|
166
|
+
"Mixture Metrics",
|
167
|
+
"mxssnr_std",
|
168
|
+
"Segmental SNR standard deviation over all frames",
|
169
|
+
),
|
170
|
+
MetricDoc(
|
171
|
+
"Mixture Metrics",
|
172
|
+
"mxssnrdb_avg",
|
173
|
+
"Segmental SNR average of the dB frame values over all frames",
|
174
|
+
),
|
175
|
+
MetricDoc(
|
176
|
+
"Mixture Metrics",
|
177
|
+
"mxssnrdb_std",
|
178
|
+
"Segmental SNR standard deviation of the dB frame values over all frames",
|
179
|
+
),
|
180
|
+
MetricDoc(
|
181
|
+
"Mixture Metrics",
|
182
|
+
"mxssnrf_avg",
|
183
|
+
"Per-bin segmental SNR average over all frames (using feature transform)",
|
184
|
+
),
|
185
|
+
MetricDoc(
|
186
|
+
"Mixture Metrics",
|
187
|
+
"mxssnrf_std",
|
188
|
+
"Per-bin segmental SNR standard deviation over all frames (using feature transform)",
|
189
|
+
),
|
190
|
+
MetricDoc(
|
191
|
+
"Mixture Metrics",
|
192
|
+
"mxssnrdbf_avg",
|
193
|
+
"Per-bin segmental average of the dB frame values over all frames (using feature transform)",
|
194
|
+
),
|
195
|
+
MetricDoc(
|
196
|
+
"Mixture Metrics",
|
197
|
+
"mxssnrdbf_std",
|
198
|
+
"Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
|
199
|
+
),
|
200
|
+
MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true target[0]"),
|
201
|
+
MetricDoc(
|
202
|
+
"Mixture Metrics",
|
203
|
+
"mxwsdr",
|
204
|
+
"Weighted signal distorion ratio of mixture versus true target[0]",
|
205
|
+
),
|
206
|
+
MetricDoc(
|
207
|
+
"Mixture Metrics",
|
208
|
+
"mxpd",
|
209
|
+
"Phase distance between mixture and true target[0]",
|
210
|
+
),
|
211
|
+
MetricDoc(
|
212
|
+
"Mixture Metrics",
|
213
|
+
"mxstoi",
|
214
|
+
"Short term objective intelligibility of mixture versus true target[0]",
|
215
|
+
),
|
216
|
+
MetricDoc(
|
217
|
+
"Mixture Metrics",
|
218
|
+
"mxcsig",
|
219
|
+
"Predicted rating of speech distortion of mixture versus true target[0]",
|
220
|
+
),
|
221
|
+
MetricDoc(
|
222
|
+
"Mixture Metrics",
|
223
|
+
"mxcbak",
|
224
|
+
"Predicted rating of background distortion of mixture versus true target[0]",
|
225
|
+
),
|
226
|
+
MetricDoc(
|
227
|
+
"Mixture Metrics",
|
228
|
+
"mxcovl",
|
229
|
+
"Predicted rating of overall quality of mixture versus true target[0]",
|
230
|
+
),
|
231
|
+
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
232
|
+
MetricDoc("Target Metrics", "tdco", "Target[0] DC offset"),
|
233
|
+
MetricDoc("Target Metrics", "tmin", "Target[0] min level"),
|
234
|
+
MetricDoc("Target Metrics", "tmax", "Target[0] max levl"),
|
235
|
+
MetricDoc("Target Metrics", "tpkdb", "Target[0] Pk lev dB"),
|
236
|
+
MetricDoc("Target Metrics", "tlrms", "Target[0] RMS lev dB"),
|
237
|
+
MetricDoc("Target Metrics", "tpkr", "Target[0] RMS Pk dB"),
|
238
|
+
MetricDoc("Target Metrics", "ttr", "Target[0] RMS Tr dB"),
|
239
|
+
MetricDoc("Target Metrics", "tcr", "Target[0] Crest factor"),
|
240
|
+
MetricDoc("Target Metrics", "tfl", "Target[0] Flat factor"),
|
241
|
+
MetricDoc("Target Metrics", "tpkc", "Target[0] Pk count"),
|
242
|
+
MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
|
243
|
+
MetricDoc("Noise Metrics", "nmin", "Noise min level"),
|
244
|
+
MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
|
245
|
+
MetricDoc("Noise Metrics", "npkdb", "Noise Pk lev dB"),
|
246
|
+
MetricDoc("Noise Metrics", "nlrms", "Noise RMS lev dB"),
|
247
|
+
MetricDoc("Noise Metrics", "npkr", "Noise RMS Pk dB"),
|
248
|
+
MetricDoc("Noise Metrics", "ntr", "Noise RMS Tr dB"),
|
249
|
+
MetricDoc("Noise Metrics", "ncr", "Noise Crest factor"),
|
250
|
+
MetricDoc("Noise Metrics", "nfl", "Noise Flat factor"),
|
251
|
+
MetricDoc("Noise Metrics", "npkc", "Noise Pk count"),
|
252
|
+
MetricDoc(
|
253
|
+
"Truth Metrics",
|
254
|
+
"sedavg",
|
255
|
+
"(not implemented) Average SED activity over all frames [truth_parameters, 1]",
|
256
|
+
),
|
257
|
+
MetricDoc(
|
258
|
+
"Truth Metrics",
|
259
|
+
"sedcnt",
|
260
|
+
"(not implemented) Count in number of frames that SED is active [truth_parameters, 1]",
|
261
|
+
),
|
262
|
+
MetricDoc(
|
263
|
+
"Truth Metrics",
|
264
|
+
"sedtop3",
|
265
|
+
"(not implemented) 3 most active by largest sedavg [3, 1]",
|
266
|
+
),
|
267
|
+
MetricDoc(
|
268
|
+
"Truth Metrics",
|
269
|
+
"sedtopn",
|
270
|
+
"(not implemented) N most active by largest sedavg [N, 1]",
|
271
|
+
),
|
272
|
+
]
|
273
|
+
)
|
216
274
|
for name in self.asr_configs:
|
217
|
-
metrics.append(
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
275
|
+
metrics.append(
|
276
|
+
MetricDoc(
|
277
|
+
"Target Metrics",
|
278
|
+
f"tasr.{name}",
|
279
|
+
f"Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
280
|
+
)
|
281
|
+
)
|
282
|
+
metrics.append(
|
283
|
+
MetricDoc(
|
284
|
+
"Mixture Metrics",
|
285
|
+
f"mxasr.{name}",
|
286
|
+
f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
|
287
|
+
)
|
288
|
+
)
|
289
|
+
metrics.append(
|
290
|
+
MetricDoc(
|
291
|
+
"Target Metrics",
|
292
|
+
f"basewer.{name}",
|
293
|
+
f"Word error rate of tasr.{name} vs. speech text metadata for the target",
|
294
|
+
)
|
295
|
+
)
|
296
|
+
metrics.append(
|
297
|
+
MetricDoc(
|
298
|
+
"Mixture Metrics",
|
299
|
+
f"mxwer.{name}",
|
300
|
+
f"Word error rate of mxasr.{name} vs. tasr.{name}",
|
301
|
+
)
|
302
|
+
)
|
225
303
|
|
226
304
|
return metrics
|
227
305
|
|
@@ -267,7 +345,7 @@ class MixtureDatabase:
|
|
267
345
|
def transform_frame_ms(self) -> float:
|
268
346
|
from .constants import SAMPLE_RATE
|
269
347
|
|
270
|
-
return float(self.ft_config.
|
348
|
+
return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
|
271
349
|
|
272
350
|
@cached_property
|
273
351
|
def feature_ms(self) -> float:
|
@@ -275,7 +353,7 @@ class MixtureDatabase:
|
|
275
353
|
|
276
354
|
@cached_property
|
277
355
|
def feature_samples(self) -> int:
|
278
|
-
return self.ft_config.
|
356
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_stride
|
279
357
|
|
280
358
|
@cached_property
|
281
359
|
def feature_step_ms(self) -> float:
|
@@ -283,28 +361,33 @@ class MixtureDatabase:
|
|
283
361
|
|
284
362
|
@cached_property
|
285
363
|
def feature_step_samples(self) -> int:
|
286
|
-
return self.ft_config.
|
364
|
+
return self.ft_config.overlap * self.fg_decimation * self.fg_step
|
287
365
|
|
288
|
-
def total_samples(self, m_ids: GeneralizedIDs =
|
289
|
-
|
366
|
+
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
|
367
|
+
samples = 0
|
368
|
+
for m_id in self.mixids_to_list(m_ids):
|
369
|
+
s = self.mixture(m_id).samples
|
370
|
+
if s is not None:
|
371
|
+
samples += s
|
372
|
+
return samples
|
290
373
|
|
291
|
-
def total_transform_frames(self, m_ids: GeneralizedIDs =
|
292
|
-
return self.total_samples(m_ids) // self.ft_config.
|
374
|
+
def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
375
|
+
return self.total_samples(m_ids) // self.ft_config.overlap
|
293
376
|
|
294
|
-
def total_feature_frames(self, m_ids: GeneralizedIDs =
|
377
|
+
def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
|
295
378
|
return self.total_samples(m_ids) // self.feature_step_samples
|
296
379
|
|
297
380
|
def mixture_transform_frames(self, m_id: int) -> int:
|
298
381
|
from .helpers import frames_from_samples
|
299
382
|
|
300
|
-
return frames_from_samples(self.mixture(m_id).samples, self.ft_config.
|
383
|
+
return frames_from_samples(self.mixture(m_id).samples, self.ft_config.overlap)
|
301
384
|
|
302
385
|
def mixture_feature_frames(self, m_id: int) -> int:
|
303
386
|
from .helpers import frames_from_samples
|
304
387
|
|
305
388
|
return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
|
306
389
|
|
307
|
-
def mixids_to_list(self, m_ids:
|
390
|
+
def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
|
308
391
|
"""Resolve generalized mixture IDs to a list of integers
|
309
392
|
|
310
393
|
:param m_ids: Generalized mixture IDs
|
@@ -321,8 +404,10 @@ class MixtureDatabase:
|
|
321
404
|
:return: Class labels
|
322
405
|
"""
|
323
406
|
with self.db() as c:
|
324
|
-
return [
|
325
|
-
|
407
|
+
return [
|
408
|
+
str(item[0])
|
409
|
+
for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
|
410
|
+
]
|
326
411
|
|
327
412
|
@cached_property
|
328
413
|
def class_weights_thresholds(self) -> list[float]:
|
@@ -331,8 +416,37 @@ class MixtureDatabase:
|
|
331
416
|
:return: Class weights thresholds
|
332
417
|
"""
|
333
418
|
with self.db() as c:
|
334
|
-
return [
|
335
|
-
|
419
|
+
return [
|
420
|
+
float(item[0])
|
421
|
+
for item in c.execute(
|
422
|
+
"SELECT class_weights_threshold.threshold FROM class_weights_threshold"
|
423
|
+
).fetchall()
|
424
|
+
]
|
425
|
+
|
426
|
+
@cached_property
|
427
|
+
def truth_configs(self) -> TruthConfigs:
|
428
|
+
"""Get truth configs from db
|
429
|
+
|
430
|
+
:return: Truth configs
|
431
|
+
"""
|
432
|
+
import json
|
433
|
+
|
434
|
+
from .datatypes import TruthConfig
|
435
|
+
|
436
|
+
with self.db() as c:
|
437
|
+
truth_configs: TruthConfigs = {}
|
438
|
+
for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
|
439
|
+
truth_config = json.loads(truth_config_record[0])
|
440
|
+
if truth_config["name"] not in truth_configs:
|
441
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
442
|
+
function=truth_config["function"],
|
443
|
+
stride_reduction=truth_config["stride_reduction"],
|
444
|
+
config=truth_config["config"],
|
445
|
+
)
|
446
|
+
return truth_configs
|
447
|
+
|
448
|
+
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
449
|
+
return _target_truth_configs(self.db, t_id)
|
336
450
|
|
337
451
|
@cached_property
|
338
452
|
def random_snrs(self) -> list[float]:
|
@@ -341,8 +455,12 @@ class MixtureDatabase:
|
|
341
455
|
:return: Random SNRs
|
342
456
|
"""
|
343
457
|
with self.db() as c:
|
344
|
-
return list(
|
345
|
-
|
458
|
+
return list(
|
459
|
+
{
|
460
|
+
float(item[0])
|
461
|
+
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
|
462
|
+
}
|
463
|
+
)
|
346
464
|
|
347
465
|
@cached_property
|
348
466
|
def snrs(self) -> list[float]:
|
@@ -351,13 +469,21 @@ class MixtureDatabase:
|
|
351
469
|
:return: SNRs
|
352
470
|
"""
|
353
471
|
with self.db() as c:
|
354
|
-
return list(
|
355
|
-
|
472
|
+
return list(
|
473
|
+
{
|
474
|
+
float(item[0])
|
475
|
+
for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
|
476
|
+
}
|
477
|
+
)
|
356
478
|
|
357
479
|
@cached_property
|
358
480
|
def all_snrs(self) -> list[UniversalSNR]:
|
359
|
-
return sorted(
|
360
|
-
|
481
|
+
return sorted(
|
482
|
+
set(
|
483
|
+
[UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
|
484
|
+
+ [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
|
485
|
+
)
|
486
|
+
)
|
361
487
|
|
362
488
|
@cached_property
|
363
489
|
def spectral_masks(self) -> SpectralMasks:
|
@@ -368,13 +494,19 @@ class MixtureDatabase:
|
|
368
494
|
from .db_datatypes import SpectralMaskRecord
|
369
495
|
|
370
496
|
with self.db() as c:
|
371
|
-
spectral_masks = [
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
497
|
+
spectral_masks = [
|
498
|
+
SpectralMaskRecord(*result) for result in c.execute("SELECT * FROM spectral_mask").fetchall()
|
499
|
+
]
|
500
|
+
return [
|
501
|
+
SpectralMask(
|
502
|
+
f_max_width=spectral_mask.f_max_width,
|
503
|
+
f_num=spectral_mask.f_num,
|
504
|
+
t_max_width=spectral_mask.t_max_width,
|
505
|
+
t_num=spectral_mask.t_num,
|
506
|
+
t_max_percent=spectral_mask.t_max_percent,
|
507
|
+
)
|
508
|
+
for spectral_mask in spectral_masks
|
509
|
+
]
|
378
510
|
|
379
511
|
def spectral_mask(self, sm_id: int) -> SpectralMask:
|
380
512
|
"""Get spectral mask with ID from db
|
@@ -392,31 +524,40 @@ class MixtureDatabase:
|
|
392
524
|
"""
|
393
525
|
import json
|
394
526
|
|
395
|
-
from .datatypes import
|
396
|
-
from .datatypes import
|
527
|
+
from .datatypes import TruthConfig
|
528
|
+
from .datatypes import TruthConfigs
|
397
529
|
from .db_datatypes import TargetFileRecord
|
398
530
|
|
399
531
|
with self.db() as c:
|
400
532
|
target_files: TargetFiles = []
|
401
|
-
target_file_records = [
|
402
|
-
|
533
|
+
target_file_records = [
|
534
|
+
TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
|
535
|
+
]
|
403
536
|
for target_file_record in target_file_records:
|
404
|
-
|
405
|
-
for
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
537
|
+
truth_configs: TruthConfigs = {}
|
538
|
+
for truth_config_records in c.execute(
|
539
|
+
"SELECT truth_config.config "
|
540
|
+
+ "FROM truth_config, target_file_truth_config "
|
541
|
+
+ "WHERE ? = target_file_truth_config.target_file_id "
|
542
|
+
+ "AND truth_config.id = target_file_truth_config.truth_config_id",
|
543
|
+
(target_file_record.id,),
|
544
|
+
).fetchall():
|
545
|
+
truth_config = json.loads(truth_config_records[0])
|
546
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
547
|
+
function=truth_config["function"],
|
548
|
+
stride_reduction=truth_config["stride_reduction"],
|
549
|
+
config=truth_config["config"],
|
550
|
+
)
|
551
|
+
target_files.append(
|
552
|
+
TargetFile(
|
553
|
+
name=target_file_record.name,
|
554
|
+
samples=target_file_record.samples,
|
555
|
+
class_indices=json.loads(target_file_record.class_indices),
|
556
|
+
level_type=target_file_record.level_type,
|
557
|
+
truth_configs=truth_configs,
|
558
|
+
speaker_id=target_file_record.speaker_id,
|
559
|
+
)
|
560
|
+
)
|
420
561
|
return target_files
|
421
562
|
|
422
563
|
@cached_property
|
@@ -452,8 +593,10 @@ class MixtureDatabase:
|
|
452
593
|
:return: Noise files
|
453
594
|
"""
|
454
595
|
with self.db() as c:
|
455
|
-
return [
|
456
|
-
|
596
|
+
return [
|
597
|
+
NoiseFile(name=noise[0], samples=noise[1])
|
598
|
+
for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
|
599
|
+
]
|
457
600
|
|
458
601
|
@cached_property
|
459
602
|
def noise_file_ids(self) -> list[int]:
|
@@ -487,9 +630,21 @@ class MixtureDatabase:
|
|
487
630
|
|
488
631
|
:return: Impulse response files
|
489
632
|
"""
|
633
|
+
import json
|
634
|
+
|
635
|
+
from .datatypes import ImpulseResponseFile
|
636
|
+
|
490
637
|
with self.db() as c:
|
491
|
-
|
492
|
-
|
638
|
+
# for impulse_response in c.execute(
|
639
|
+
# "SELECT impulse_response_file.* FROM impulse_response_file"
|
640
|
+
# ).fetchall():
|
641
|
+
# print(impulse_response)
|
642
|
+
return [
|
643
|
+
ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
|
644
|
+
for impulse_response in c.execute(
|
645
|
+
"SELECT impulse_response_file.* FROM impulse_response_file"
|
646
|
+
).fetchall()
|
647
|
+
]
|
493
648
|
|
494
649
|
@cached_property
|
495
650
|
def impulse_response_file_ids(self) -> list[int]:
|
@@ -498,15 +653,19 @@ class MixtureDatabase:
|
|
498
653
|
:return: List of impulse response file IDs
|
499
654
|
"""
|
500
655
|
with self.db() as c:
|
501
|
-
return [
|
502
|
-
|
656
|
+
return [
|
657
|
+
int(item[0])
|
658
|
+
for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
|
659
|
+
]
|
503
660
|
|
504
|
-
def impulse_response_file(self, ir_id: int) -> str:
|
661
|
+
def impulse_response_file(self, ir_id: int | None) -> str | None:
|
505
662
|
"""Get impulse response file with ID from db
|
506
663
|
|
507
664
|
:param ir_id: Impulse response file ID
|
508
665
|
:return: Noise
|
509
666
|
"""
|
667
|
+
if ir_id is None:
|
668
|
+
return None
|
510
669
|
return _impulse_response_file(self.db, ir_id)
|
511
670
|
|
512
671
|
@cached_property
|
@@ -524,18 +683,22 @@ class MixtureDatabase:
|
|
524
683
|
|
525
684
|
:return: Mixtures
|
526
685
|
"""
|
527
|
-
from .helpers import to_mixture
|
528
|
-
from .helpers import to_target
|
529
686
|
from .db_datatypes import MixtureRecord
|
530
687
|
from .db_datatypes import TargetRecord
|
688
|
+
from .helpers import to_mixture
|
689
|
+
from .helpers import to_target
|
531
690
|
|
532
691
|
with self.db() as c:
|
533
692
|
mixtures: Mixtures = []
|
534
693
|
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
535
|
-
targets = [
|
536
|
-
|
537
|
-
|
538
|
-
|
694
|
+
targets = [
|
695
|
+
to_target(TargetRecord(*target))
|
696
|
+
for target in c.execute(
|
697
|
+
"SELECT target.* FROM target, mixture_target "
|
698
|
+
+ "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
699
|
+
(mixture.id,),
|
700
|
+
).fetchall()
|
701
|
+
]
|
539
702
|
mixtures.append(to_mixture(mixture, targets))
|
540
703
|
return mixtures
|
541
704
|
|
@@ -561,23 +724,15 @@ class MixtureDatabase:
|
|
561
724
|
with self.db() as c:
|
562
725
|
return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
|
563
726
|
|
564
|
-
def
|
565
|
-
"""
|
727
|
+
def mixture_location(self, m_id: int) -> str:
|
728
|
+
"""Get the file location for the give mixture ID
|
566
729
|
|
567
|
-
:param
|
568
|
-
:return:
|
730
|
+
:param m_id: Zero-based mixture ID
|
731
|
+
:return: File location
|
569
732
|
"""
|
570
733
|
from os.path import join
|
571
734
|
|
572
|
-
return join(self.location, name)
|
573
|
-
|
574
|
-
def mixture_filename(self, m_id: int) -> str:
|
575
|
-
"""Get the HDF5 file name for the give mixture ID
|
576
|
-
|
577
|
-
:param m_id: Zero-based mixture ID
|
578
|
-
:return: File name
|
579
|
-
"""
|
580
|
-
return self.location_filename(self.mixture(m_id).name)
|
735
|
+
return join(self.location, self.mixture(m_id).name)
|
581
736
|
|
582
737
|
@cached_property
|
583
738
|
def num_mixtures(self) -> int:
|
@@ -595,9 +750,9 @@ class MixtureDatabase:
|
|
595
750
|
:param items: String(s) of dataset(s) to retrieve
|
596
751
|
:return: Data (or tuple of data)
|
597
752
|
"""
|
598
|
-
from .
|
753
|
+
from sonusai.mixture import read_cached_data
|
599
754
|
|
600
|
-
return
|
755
|
+
return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
|
601
756
|
|
602
757
|
def read_target_audio(self, t_id: int) -> AudioT:
|
603
758
|
"""Read target audio
|
@@ -624,10 +779,19 @@ class MixtureDatabase:
|
|
624
779
|
audio = read_audio(noise.name)
|
625
780
|
audio = apply_augmentation(audio, mixture.noise.augmentation)
|
626
781
|
if mixture.noise.augmentation.ir is not None:
|
627
|
-
audio = apply_impulse_response(
|
782
|
+
audio = apply_impulse_response(
|
783
|
+
audio,
|
784
|
+
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
|
785
|
+
)
|
628
786
|
|
629
787
|
return audio
|
630
788
|
|
789
|
+
def mixture_class_indices(self, m_id: int) -> list[int]:
|
790
|
+
class_indices: list[int] = []
|
791
|
+
for t_id in self.mixture(m_id).target_ids:
|
792
|
+
class_indices.extend(self.target_file(t_id).class_indices)
|
793
|
+
return sorted(set(class_indices))
|
794
|
+
|
631
795
|
def mixture_targets(self, m_id: int, force: bool = False) -> AudiosT:
|
632
796
|
"""Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
|
633
797
|
|
@@ -635,36 +799,34 @@ class MixtureDatabase:
|
|
635
799
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
636
800
|
:return: List of augmented target audio data (one per target in the mixup)
|
637
801
|
"""
|
638
|
-
from sonusai import SonusAIError
|
639
802
|
from .augmentation import apply_augmentation
|
640
803
|
from .augmentation import apply_gain
|
641
804
|
from .augmentation import pad_audio_to_length
|
642
805
|
|
643
806
|
if not force:
|
644
|
-
targets_audio = self.read_mixture_data(m_id,
|
807
|
+
targets_audio = self.read_mixture_data(m_id, "targets")
|
645
808
|
if targets_audio is not None:
|
646
809
|
return list(targets_audio)
|
647
810
|
|
648
811
|
mixture = self.mixture(m_id)
|
649
812
|
if mixture is None:
|
650
|
-
raise
|
813
|
+
raise ValueError(f"Could not find mixture for m_id: {m_id}")
|
651
814
|
|
652
815
|
targets_audio = []
|
653
816
|
for target in mixture.targets:
|
654
817
|
target_audio = self.read_target_audio(target.file_id)
|
655
|
-
target_audio = apply_augmentation(
|
656
|
-
|
657
|
-
|
818
|
+
target_audio = apply_augmentation(
|
819
|
+
audio=target_audio,
|
820
|
+
augmentation=target.augmentation,
|
821
|
+
frame_length=self.feature_step_samples,
|
822
|
+
)
|
658
823
|
target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
|
659
824
|
target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
|
660
825
|
targets_audio.append(target_audio)
|
661
826
|
|
662
827
|
return targets_audio
|
663
828
|
|
664
|
-
def mixture_targets_f(self,
|
665
|
-
m_id: int,
|
666
|
-
targets: Optional[AudiosT] = None,
|
667
|
-
force: bool = False) -> AudiosF:
|
829
|
+
def mixture_targets_f(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudiosF:
|
668
830
|
"""Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
|
669
831
|
|
670
832
|
:param m_id: Zero-based mixture ID
|
@@ -679,10 +841,7 @@ class MixtureDatabase:
|
|
679
841
|
|
680
842
|
return [forward_transform(target, self.ft_config) for target in targets]
|
681
843
|
|
682
|
-
def mixture_target(self,
|
683
|
-
m_id: int,
|
684
|
-
targets: Optional[AudiosT] = None,
|
685
|
-
force: bool = False) -> AudioT:
|
844
|
+
def mixture_target(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudioT:
|
686
845
|
"""Get the augmented target audio data for the given mixture ID
|
687
846
|
|
688
847
|
:param m_id: Zero-based mixture ID
|
@@ -693,7 +852,7 @@ class MixtureDatabase:
|
|
693
852
|
from .helpers import get_target
|
694
853
|
|
695
854
|
if not force:
|
696
|
-
target = self.read_mixture_data(m_id,
|
855
|
+
target = self.read_mixture_data(m_id, "target")
|
697
856
|
if target is not None:
|
698
857
|
return target
|
699
858
|
|
@@ -702,11 +861,13 @@ class MixtureDatabase:
|
|
702
861
|
|
703
862
|
return get_target(self, self.mixture(m_id), targets)
|
704
863
|
|
705
|
-
def mixture_target_f(
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
864
|
+
def mixture_target_f(
|
865
|
+
self,
|
866
|
+
m_id: int,
|
867
|
+
targets: AudiosT | None = None,
|
868
|
+
target: AudioT | None = None,
|
869
|
+
force: bool = False,
|
870
|
+
) -> AudioF:
|
710
871
|
"""Get the augmented target transform data for the given mixture ID
|
711
872
|
|
712
873
|
:param m_id: Zero-based mixture ID
|
@@ -722,9 +883,7 @@ class MixtureDatabase:
|
|
722
883
|
|
723
884
|
return forward_transform(target, self.ft_config)
|
724
885
|
|
725
|
-
def mixture_noise(self,
|
726
|
-
m_id: int,
|
727
|
-
force: bool = False) -> AudioT:
|
886
|
+
def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
|
728
887
|
"""Get the augmented noise audio data for the given mixture ID
|
729
888
|
|
730
889
|
:param m_id: Zero-based mixture ID
|
@@ -735,7 +894,7 @@ class MixtureDatabase:
|
|
735
894
|
from .augmentation import apply_gain
|
736
895
|
|
737
896
|
if not force:
|
738
|
-
noise = self.read_mixture_data(m_id,
|
897
|
+
noise = self.read_mixture_data(m_id, "noise")
|
739
898
|
if noise is not None:
|
740
899
|
return noise
|
741
900
|
|
@@ -744,10 +903,7 @@ class MixtureDatabase:
|
|
744
903
|
noise = get_next_noise(audio=noise, offset=mixture.noise.offset, length=mixture.samples)
|
745
904
|
return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
|
746
905
|
|
747
|
-
def mixture_noise_f(self,
|
748
|
-
m_id: int,
|
749
|
-
noise: Optional[AudioT] = None,
|
750
|
-
force: bool = False) -> AudioF:
|
906
|
+
def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
|
751
907
|
"""Get the augmented noise transform for the given mixture ID
|
752
908
|
|
753
909
|
:param m_id: Zero-based mixture ID
|
@@ -762,12 +918,14 @@ class MixtureDatabase:
|
|
762
918
|
|
763
919
|
return forward_transform(noise, self.ft_config)
|
764
920
|
|
765
|
-
def mixture_mixture(
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
921
|
+
def mixture_mixture(
|
922
|
+
self,
|
923
|
+
m_id: int,
|
924
|
+
targets: AudiosT | None = None,
|
925
|
+
target: AudioT | None = None,
|
926
|
+
noise: AudioT | None = None,
|
927
|
+
force: bool = False,
|
928
|
+
) -> AudioT:
|
771
929
|
"""Get the mixture audio data for the given mixture ID
|
772
930
|
|
773
931
|
:param m_id: Zero-based mixture ID
|
@@ -778,7 +936,7 @@ class MixtureDatabase:
|
|
778
936
|
:return: Mixture audio data
|
779
937
|
"""
|
780
938
|
if not force:
|
781
|
-
mixture = self.read_mixture_data(m_id,
|
939
|
+
mixture = self.read_mixture_data(m_id, "mixture")
|
782
940
|
if mixture is not None:
|
783
941
|
return mixture
|
784
942
|
|
@@ -790,13 +948,15 @@ class MixtureDatabase:
|
|
790
948
|
|
791
949
|
return target + noise
|
792
950
|
|
793
|
-
def mixture_mixture_f(
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
951
|
+
def mixture_mixture_f(
|
952
|
+
self,
|
953
|
+
m_id: int,
|
954
|
+
targets: AudiosT | None = None,
|
955
|
+
target: AudioT | None = None,
|
956
|
+
noise: AudioT | None = None,
|
957
|
+
mixture: AudioT | None = None,
|
958
|
+
force: bool = False,
|
959
|
+
) -> AudioF:
|
800
960
|
"""Get the mixture transform for the given mixture ID
|
801
961
|
|
802
962
|
:param m_id: Zero-based mixture ID
|
@@ -817,18 +977,22 @@ class MixtureDatabase:
|
|
817
977
|
|
818
978
|
m = self.mixture(m_id)
|
819
979
|
if m.spectral_mask_id is not None:
|
820
|
-
mixture_f = apply_spectral_mask(
|
821
|
-
|
822
|
-
|
980
|
+
mixture_f = apply_spectral_mask(
|
981
|
+
audio_f=mixture_f,
|
982
|
+
spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
|
983
|
+
seed=m.spectral_mask_seed,
|
984
|
+
)
|
823
985
|
|
824
986
|
return mixture_f
|
825
987
|
|
826
|
-
def mixture_truth_t(
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
988
|
+
def mixture_truth_t(
|
989
|
+
self,
|
990
|
+
m_id: int,
|
991
|
+
targets: AudiosT | None = None,
|
992
|
+
noise: AudioT | None = None,
|
993
|
+
mixture: AudioT | None = None,
|
994
|
+
force: bool = False,
|
995
|
+
) -> TruthDict:
|
832
996
|
"""Get the truth_t data for the given mixture ID
|
833
997
|
|
834
998
|
:param m_id: Zero-based mixture ID
|
@@ -838,10 +1002,10 @@ class MixtureDatabase:
|
|
838
1002
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
839
1003
|
:return: truth_t data
|
840
1004
|
"""
|
841
|
-
from .helpers import
|
1005
|
+
from .helpers import get_truth
|
842
1006
|
|
843
1007
|
if not force:
|
844
|
-
truth_t = self.read_mixture_data(m_id,
|
1008
|
+
truth_t = self.read_mixture_data(m_id, "truth_t")
|
845
1009
|
if truth_t is not None:
|
846
1010
|
return truth_t
|
847
1011
|
|
@@ -852,19 +1016,18 @@ class MixtureDatabase:
|
|
852
1016
|
noise = self.mixture_noise(m_id, force)
|
853
1017
|
|
854
1018
|
if force or mixture is None:
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
force: bool = False) -> Segsnr:
|
1019
|
+
mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
|
1020
|
+
|
1021
|
+
return get_truth(self, self.mixture(m_id), targets, noise, mixture)
|
1022
|
+
|
1023
|
+
def mixture_segsnr_t(
|
1024
|
+
self,
|
1025
|
+
m_id: int,
|
1026
|
+
targets: AudiosT | None = None,
|
1027
|
+
target: AudioT | None = None,
|
1028
|
+
noise: AudioT | None = None,
|
1029
|
+
force: bool = False,
|
1030
|
+
) -> Segsnr:
|
868
1031
|
"""Get the segsnr_t data for the given mixture ID
|
869
1032
|
|
870
1033
|
:param m_id: Zero-based mixture ID
|
@@ -877,7 +1040,7 @@ class MixtureDatabase:
|
|
877
1040
|
from .helpers import get_segsnr_t
|
878
1041
|
|
879
1042
|
if not force:
|
880
|
-
segsnr_t = self.read_mixture_data(m_id,
|
1043
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
881
1044
|
if segsnr_t is not None:
|
882
1045
|
return segsnr_t
|
883
1046
|
|
@@ -889,13 +1052,15 @@ class MixtureDatabase:
|
|
889
1052
|
|
890
1053
|
return get_segsnr_t(self, self.mixture(m_id), target, noise)
|
891
1054
|
|
892
|
-
def mixture_segsnr(
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
1055
|
+
def mixture_segsnr(
|
1056
|
+
self,
|
1057
|
+
m_id: int,
|
1058
|
+
segsnr_t: Segsnr | None = None,
|
1059
|
+
targets: AudiosT | None = None,
|
1060
|
+
target: AudioT | None = None,
|
1061
|
+
noise: AudioT | None = None,
|
1062
|
+
force: bool = False,
|
1063
|
+
) -> Segsnr:
|
899
1064
|
"""Get the segsnr data for the given mixture ID
|
900
1065
|
|
901
1066
|
:param m_id: Zero-based mixture ID
|
@@ -907,28 +1072,30 @@ class MixtureDatabase:
|
|
907
1072
|
:return: segsnr data
|
908
1073
|
"""
|
909
1074
|
if not force:
|
910
|
-
segsnr = self.read_mixture_data(m_id,
|
1075
|
+
segsnr = self.read_mixture_data(m_id, "segsnr")
|
911
1076
|
if segsnr is not None:
|
912
1077
|
return segsnr
|
913
1078
|
|
914
|
-
segsnr_t = self.read_mixture_data(m_id,
|
1079
|
+
segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
|
915
1080
|
if segsnr_t is not None:
|
916
|
-
return segsnr_t[0::self.ft_config.
|
1081
|
+
return segsnr_t[0 :: self.ft_config.overlap]
|
917
1082
|
|
918
1083
|
if force or segsnr_t is None:
|
919
1084
|
segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
|
920
1085
|
|
921
|
-
return segsnr_t[0::self.ft_config.
|
922
|
-
|
923
|
-
def mixture_ft(
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
1086
|
+
return segsnr_t[0 :: self.ft_config.overlap]
|
1087
|
+
|
1088
|
+
def mixture_ft(
|
1089
|
+
self,
|
1090
|
+
m_id: int,
|
1091
|
+
targets: AudiosT | None = None,
|
1092
|
+
target: AudioT | None = None,
|
1093
|
+
noise: AudioT | None = None,
|
1094
|
+
mixture_f: AudioF | None = None,
|
1095
|
+
mixture: AudioT | None = None,
|
1096
|
+
truth_t: TruthDict | None = None,
|
1097
|
+
force: bool = False,
|
1098
|
+
) -> tuple[Feature, TruthDict]:
|
932
1099
|
"""Get the feature and truth_f data for the given mixture ID
|
933
1100
|
|
934
1101
|
:param m_id: Zero-based mixture ID
|
@@ -941,63 +1108,45 @@ class MixtureDatabase:
|
|
941
1108
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
942
1109
|
:return: Tuple of (feature, truth_f) data
|
943
1110
|
"""
|
944
|
-
from dataclasses import asdict
|
945
|
-
|
946
|
-
import numpy as np
|
947
1111
|
from pyaaware import FeatureGenerator
|
948
1112
|
|
949
|
-
from .truth import
|
1113
|
+
from .truth import truth_stride_reduction
|
950
1114
|
|
951
1115
|
if not force:
|
952
|
-
feature, truth_f = self.read_mixture_data(m_id, [
|
1116
|
+
feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
|
953
1117
|
if feature is not None and truth_f is not None:
|
954
1118
|
return feature, truth_f
|
955
1119
|
|
956
1120
|
if force or mixture_f is None:
|
957
|
-
mixture_f = self.mixture_mixture_f(
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
1121
|
+
mixture_f = self.mixture_mixture_f(
|
1122
|
+
m_id=m_id,
|
1123
|
+
targets=targets,
|
1124
|
+
target=target,
|
1125
|
+
noise=noise,
|
1126
|
+
mixture=mixture,
|
1127
|
+
force=force,
|
1128
|
+
)
|
963
1129
|
|
964
1130
|
if force or truth_t is None:
|
965
1131
|
truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
|
966
1132
|
|
967
|
-
|
968
|
-
transform_frames = self.mixture_transform_frames(m_id)
|
969
|
-
feature_frames = self.mixture_feature_frames(m_id)
|
970
|
-
|
971
|
-
if truth_t is None:
|
972
|
-
truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
|
973
|
-
|
974
|
-
feature = np.empty((feature_frames, self.fg_stride, self.feature_parameters), dtype=np.float32)
|
975
|
-
truth_f = np.empty((feature_frames, self.num_classes), dtype=np.complex64)
|
976
|
-
|
977
|
-
fg = FeatureGenerator(**asdict(self.fg_config))
|
978
|
-
feature_frame = 0
|
979
|
-
for transform_frame in range(transform_frames):
|
980
|
-
indices = slice(transform_frame * self.ft_config.R, (transform_frame + 1) * self.ft_config.R)
|
981
|
-
fg.execute(mixture_f[transform_frame],
|
982
|
-
truth_reduction(truth_t[indices], self.truth_reduction_function))
|
1133
|
+
fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
|
983
1134
|
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
feature_frame += 1
|
988
|
-
|
989
|
-
if np.isreal(truth_f).all():
|
990
|
-
return feature, truth_f.real
|
1135
|
+
feature, truth_f = fg.execute_all(mixture_f, truth_t)
|
1136
|
+
for key in self.truth_configs:
|
1137
|
+
truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
|
991
1138
|
|
992
1139
|
return feature, truth_f
|
993
1140
|
|
994
|
-
def mixture_feature(
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1141
|
+
def mixture_feature(
|
1142
|
+
self,
|
1143
|
+
m_id: int,
|
1144
|
+
targets: AudiosT | None = None,
|
1145
|
+
noise: AudioT | None = None,
|
1146
|
+
mixture: AudioT | None = None,
|
1147
|
+
truth_t: TruthDict | None = None,
|
1148
|
+
force: bool = False,
|
1149
|
+
) -> Feature:
|
1001
1150
|
"""Get the feature data for the given mixture ID
|
1002
1151
|
|
1003
1152
|
:param m_id: Zero-based mixture ID
|
@@ -1008,21 +1157,25 @@ class MixtureDatabase:
|
|
1008
1157
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1009
1158
|
:return: Feature data
|
1010
1159
|
"""
|
1011
|
-
feature, _ = self.mixture_ft(
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1160
|
+
feature, _ = self.mixture_ft(
|
1161
|
+
m_id=m_id,
|
1162
|
+
targets=targets,
|
1163
|
+
noise=noise,
|
1164
|
+
mixture=mixture,
|
1165
|
+
truth_t=truth_t,
|
1166
|
+
force=force,
|
1167
|
+
)
|
1017
1168
|
return feature
|
1018
1169
|
|
1019
|
-
def mixture_truth_f(
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1170
|
+
def mixture_truth_f(
|
1171
|
+
self,
|
1172
|
+
m_id: int,
|
1173
|
+
targets: AudiosT | None = None,
|
1174
|
+
noise: AudioT | None = None,
|
1175
|
+
mixture: AudioT | None = None,
|
1176
|
+
truth_t: TruthDict | None = None,
|
1177
|
+
force: bool = False,
|
1178
|
+
) -> TruthDict:
|
1026
1179
|
"""Get the truth_f data for the given mixture ID
|
1027
1180
|
|
1028
1181
|
:param m_id: Zero-based mixture ID
|
@@ -1033,20 +1186,24 @@ class MixtureDatabase:
|
|
1033
1186
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1034
1187
|
:return: truth_f data
|
1035
1188
|
"""
|
1036
|
-
_, truth_f = self.mixture_ft(
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1189
|
+
_, truth_f = self.mixture_ft(
|
1190
|
+
m_id=m_id,
|
1191
|
+
targets=targets,
|
1192
|
+
noise=noise,
|
1193
|
+
mixture=mixture,
|
1194
|
+
truth_t=truth_t,
|
1195
|
+
force=force,
|
1196
|
+
)
|
1042
1197
|
return truth_f
|
1043
1198
|
|
1044
|
-
def mixture_class_count(
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1199
|
+
def mixture_class_count(
|
1200
|
+
self,
|
1201
|
+
m_id: int,
|
1202
|
+
targets: AudiosT | None = None,
|
1203
|
+
noise: AudioT | None = None,
|
1204
|
+
truth_t: TruthDict | None = None,
|
1205
|
+
) -> ClassCount:
|
1206
|
+
"""Compute the number of frames for which each class index is active for the given mixture ID
|
1050
1207
|
|
1051
1208
|
:param m_id: Zero-based mixture ID
|
1052
1209
|
:param targets: List of augmented target audio (one per target in the mixup)
|
@@ -1061,10 +1218,9 @@ class MixtureDatabase:
|
|
1061
1218
|
|
1062
1219
|
class_count = [0] * self.num_classes
|
1063
1220
|
num_classes = self.num_classes
|
1064
|
-
if self.
|
1065
|
-
|
1066
|
-
|
1067
|
-
class_count[cl] = int(np.sum(truth_t[:, cl] >= self.class_weights_thresholds[cl]))
|
1221
|
+
if "sed" in self.truth_configs:
|
1222
|
+
for cl in range(num_classes):
|
1223
|
+
class_count[cl] = int(np.sum(truth_t["sed"][:, cl] >= self.class_weights_thresholds[cl]))
|
1068
1224
|
|
1069
1225
|
return class_count
|
1070
1226
|
|
@@ -1086,7 +1242,7 @@ class MixtureDatabase:
|
|
1086
1242
|
def speech_metadata_tiers(self) -> list[str]:
|
1087
1243
|
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1088
1244
|
|
1089
|
-
def speaker(self, s_id: int | None, tier: str) ->
|
1245
|
+
def speaker(self, s_id: int | None, tier: str) -> str | None:
|
1090
1246
|
return _speaker(self.db, s_id, tier)
|
1091
1247
|
|
1092
1248
|
def speech_metadata(self, tier: str) -> list[str]:
|
@@ -1126,9 +1282,13 @@ class MixtureDatabase:
|
|
1126
1282
|
entries = []
|
1127
1283
|
for entry in data:
|
1128
1284
|
if target.augmentation.tempo is not None:
|
1129
|
-
entries.append(
|
1130
|
-
|
1131
|
-
|
1285
|
+
entries.append(
|
1286
|
+
Interval(
|
1287
|
+
entry.start / target.augmentation.tempo,
|
1288
|
+
entry.end / target.augmentation.tempo,
|
1289
|
+
entry.label,
|
1290
|
+
)
|
1291
|
+
)
|
1132
1292
|
else:
|
1133
1293
|
entries.append(entry)
|
1134
1294
|
results.append(entries)
|
@@ -1138,12 +1298,9 @@ class MixtureDatabase:
|
|
1138
1298
|
for target in self.mixture(mixid).targets:
|
1139
1299
|
results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
|
1140
1300
|
|
1141
|
-
return
|
1301
|
+
return results
|
1142
1302
|
|
1143
|
-
def mixids_for_speech_metadata(self,
|
1144
|
-
tier: str,
|
1145
|
-
value: str = None,
|
1146
|
-
where: str = None) -> list[int]:
|
1303
|
+
def mixids_for_speech_metadata(self, tier: str, value: str | None = None, where: str | None = None) -> list[int]:
|
1147
1304
|
"""Get a list of mixture IDs for the given speech metadata tier.
|
1148
1305
|
|
1149
1306
|
If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
@@ -1162,25 +1319,27 @@ class MixtureDatabase:
|
|
1162
1319
|
>>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
|
1163
1320
|
Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
|
1164
1321
|
"""
|
1165
|
-
from sonusai import SonusAIError
|
1166
|
-
|
1167
1322
|
if value is None and where is None:
|
1168
|
-
raise
|
1323
|
+
raise ValueError("Must provide either value or where")
|
1169
1324
|
|
1170
1325
|
if where is None:
|
1171
1326
|
where = f"{tier} = '{value}'"
|
1172
1327
|
|
1173
1328
|
if tier in self.textgrid_metadata_tiers:
|
1174
|
-
raise
|
1329
|
+
raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
|
1175
1330
|
|
1176
1331
|
with self.db() as c:
|
1177
|
-
speaker_ids = [
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1332
|
+
speaker_ids = [
|
1333
|
+
speaker_id[0] for speaker_id in c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
|
1334
|
+
]
|
1335
|
+
results = c.execute(
|
1336
|
+
"SELECT id FROM target_file " + f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})"
|
1337
|
+
).fetchall()
|
1181
1338
|
target_file_ids = [target_file_id[0] for target_file_id in results]
|
1182
|
-
results = c.execute(
|
1183
|
-
|
1339
|
+
results = c.execute(
|
1340
|
+
"SELECT mixture_id FROM mixture_target "
|
1341
|
+
+ f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
|
1342
|
+
).fetchall()
|
1184
1343
|
|
1185
1344
|
return [mixture_id[0] - 1 for mixture_id in results]
|
1186
1345
|
|
@@ -1189,9 +1348,9 @@ class MixtureDatabase:
|
|
1189
1348
|
|
1190
1349
|
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
1191
1350
|
|
1192
|
-
def mixture_metrics(
|
1193
|
-
|
1194
|
-
|
1351
|
+
def mixture_metrics(
|
1352
|
+
self, m_id: int, metrics: list[str], force: bool = False
|
1353
|
+
) -> list[float | int | str | Segsnr | None]:
|
1195
1354
|
"""Get metrics data for the given mixture ID
|
1196
1355
|
|
1197
1356
|
:param m_id: Zero-based mixture ID
|
@@ -1199,12 +1358,11 @@ class MixtureDatabase:
|
|
1199
1358
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
1200
1359
|
:return: List of metric data
|
1201
1360
|
"""
|
1202
|
-
from
|
1361
|
+
from collections.abc import Callable
|
1203
1362
|
|
1204
1363
|
import numpy as np
|
1205
1364
|
from pystoi import stoi
|
1206
1365
|
|
1207
|
-
from sonusai import SonusAIError
|
1208
1366
|
from sonusai.metrics import calc_audio_stats
|
1209
1367
|
from sonusai.metrics import calc_phase_distance
|
1210
1368
|
from sonusai.metrics import calc_segsnr_f
|
@@ -1314,7 +1472,7 @@ class MixtureDatabase:
|
|
1314
1472
|
def get() -> AudioStatsMetrics:
|
1315
1473
|
nonlocal state
|
1316
1474
|
if state is None:
|
1317
|
-
state = calc_audio_stats(target_audio(), self.fg_info.ft_config.
|
1475
|
+
state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1318
1476
|
return state
|
1319
1477
|
|
1320
1478
|
return get
|
@@ -1327,7 +1485,7 @@ class MixtureDatabase:
|
|
1327
1485
|
def get() -> AudioStatsMetrics:
|
1328
1486
|
nonlocal state
|
1329
1487
|
if state is None:
|
1330
|
-
state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.
|
1488
|
+
state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1331
1489
|
return state
|
1332
1490
|
|
1333
1491
|
return get
|
@@ -1340,9 +1498,10 @@ class MixtureDatabase:
|
|
1340
1498
|
def get(asr_name) -> dict:
|
1341
1499
|
nonlocal state
|
1342
1500
|
if asr_name not in state:
|
1343
|
-
|
1344
|
-
if
|
1345
|
-
raise
|
1501
|
+
value = self.asr_configs.get(asr_name, None)
|
1502
|
+
if value is None:
|
1503
|
+
raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
|
1504
|
+
state[asr_name] = value
|
1346
1505
|
return state[asr_name]
|
1347
1506
|
|
1348
1507
|
return get
|
@@ -1376,15 +1535,14 @@ class MixtureDatabase:
|
|
1376
1535
|
mixture_asr = create_mixture_asr()
|
1377
1536
|
|
1378
1537
|
def get_asr_name(m: str) -> str:
|
1379
|
-
parts = m.split(
|
1538
|
+
parts = m.split(".")
|
1380
1539
|
if len(parts) != 2:
|
1381
|
-
raise
|
1382
|
-
f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1540
|
+
raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
|
1383
1541
|
asr_name = parts[1]
|
1384
1542
|
return asr_name
|
1385
1543
|
|
1386
|
-
def calc(m: str) -> float | int | str | Segsnr:
|
1387
|
-
if m ==
|
1544
|
+
def calc(m: str) -> float | int | str | Segsnr | None:
|
1545
|
+
if m == "mxsnr":
|
1388
1546
|
return self.mixture(m_id).snr
|
1389
1547
|
|
1390
1548
|
# Get cached data first, if exists
|
@@ -1394,12 +1552,12 @@ class MixtureDatabase:
|
|
1394
1552
|
return value
|
1395
1553
|
|
1396
1554
|
# Otherwise, generate data as needed
|
1397
|
-
if m.startswith(
|
1555
|
+
if m.startswith("mxwer"):
|
1398
1556
|
asr_name = get_asr_name(m)
|
1399
1557
|
|
1400
1558
|
if self.mixture(m_id).snr < -96:
|
1401
1559
|
# noise only, ignore/reset target asr
|
1402
|
-
return float(
|
1560
|
+
return float("nan")
|
1403
1561
|
|
1404
1562
|
if target_asr(asr_name):
|
1405
1563
|
return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
|
@@ -1407,159 +1565,166 @@ class MixtureDatabase:
|
|
1407
1565
|
# TODO: should this be NaN like above?
|
1408
1566
|
return float(0)
|
1409
1567
|
|
1410
|
-
if m.startswith(
|
1568
|
+
if m.startswith("basewer"):
|
1411
1569
|
asr_name = get_asr_name(m)
|
1412
1570
|
|
1413
|
-
text = self.mixture_speech_metadata(m_id,
|
1571
|
+
text = self.mixture_speech_metadata(m_id, "text")[0]
|
1414
1572
|
if text is not None:
|
1415
1573
|
return calc_wer(target_asr(asr_name), text).wer * 100
|
1416
1574
|
|
1417
1575
|
# TODO: should this be NaN like above?
|
1418
1576
|
return float(0)
|
1419
1577
|
|
1420
|
-
if m.startswith(
|
1578
|
+
if m.startswith("mxasr"):
|
1421
1579
|
return mixture_asr(get_asr_name(m))
|
1422
1580
|
|
1423
|
-
if m ==
|
1581
|
+
if m == "mxssnr_avg":
|
1424
1582
|
return calc_segsnr_f(segsnr_f()).avg
|
1425
1583
|
|
1426
|
-
if m ==
|
1584
|
+
if m == "mxssnr_std":
|
1427
1585
|
return calc_segsnr_f(segsnr_f()).std
|
1428
1586
|
|
1429
|
-
if m ==
|
1587
|
+
if m == "mxssnrdb_avg":
|
1430
1588
|
return calc_segsnr_f(segsnr_f()).db_avg
|
1431
1589
|
|
1432
|
-
if m ==
|
1590
|
+
if m == "mxssnrdb_std":
|
1433
1591
|
return calc_segsnr_f(segsnr_f()).db_std
|
1434
1592
|
|
1435
|
-
if m ==
|
1593
|
+
if m == "mxssnrf_avg":
|
1436
1594
|
return calc_segsnr_f_bin(target_f(), noise_f()).avg
|
1437
1595
|
|
1438
|
-
if m ==
|
1596
|
+
if m == "mxssnrf_std":
|
1439
1597
|
return calc_segsnr_f_bin(target_f(), noise_f()).std
|
1440
1598
|
|
1441
|
-
if m ==
|
1599
|
+
if m == "mxssnrdbf_avg":
|
1442
1600
|
return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
|
1443
1601
|
|
1444
|
-
if m ==
|
1602
|
+
if m == "mxssnrdbf_std":
|
1445
1603
|
return calc_segsnr_f_bin(target_f(), noise_f()).db_std
|
1446
1604
|
|
1447
|
-
if m ==
|
1605
|
+
if m == "mxpesq":
|
1448
1606
|
if self.mixture(m_id).snr < -96:
|
1449
1607
|
return 0
|
1450
1608
|
return speech().pesq
|
1451
1609
|
|
1452
|
-
if m ==
|
1610
|
+
if m == "mxcsig":
|
1453
1611
|
if self.mixture(m_id).snr < -96:
|
1454
1612
|
return 0
|
1455
1613
|
return speech().csig
|
1456
1614
|
|
1457
|
-
if m ==
|
1615
|
+
if m == "mxcbak":
|
1458
1616
|
if self.mixture(m_id).snr < -96:
|
1459
1617
|
return 0
|
1460
1618
|
return speech().cbak
|
1461
1619
|
|
1462
|
-
if m ==
|
1620
|
+
if m == "mxcovl":
|
1463
1621
|
if self.mixture(m_id).snr < -96:
|
1464
1622
|
return 0
|
1465
1623
|
return speech().covl
|
1466
1624
|
|
1467
|
-
if m ==
|
1625
|
+
if m == "mxwsdr":
|
1468
1626
|
mixture = mixture_audio()[:, np.newaxis]
|
1469
1627
|
target = target_audio()[:, np.newaxis]
|
1470
1628
|
noise = noise_audio()[:, np.newaxis]
|
1471
|
-
return calc_wsdr(
|
1472
|
-
|
1473
|
-
|
1629
|
+
return calc_wsdr(
|
1630
|
+
hypothesis=np.concatenate((mixture, noise), axis=1),
|
1631
|
+
reference=np.concatenate((target, noise), axis=1),
|
1632
|
+
with_log=True,
|
1633
|
+
)[0]
|
1474
1634
|
|
1475
|
-
if m ==
|
1635
|
+
if m == "mxpd":
|
1476
1636
|
mixture_f = self.mixture_mixture_f(m_id)
|
1477
1637
|
return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
|
1478
1638
|
|
1479
|
-
if m ==
|
1480
|
-
return stoi(
|
1639
|
+
if m == "mxstoi":
|
1640
|
+
return stoi(
|
1641
|
+
x=target_audio(),
|
1642
|
+
y=mixture_audio(),
|
1643
|
+
fs_sig=SAMPLE_RATE,
|
1644
|
+
extended=False,
|
1645
|
+
)
|
1481
1646
|
|
1482
|
-
if m ==
|
1647
|
+
if m == "tdco":
|
1483
1648
|
return target_stats().dco
|
1484
1649
|
|
1485
|
-
if m ==
|
1650
|
+
if m == "tmin":
|
1486
1651
|
return target_stats().min
|
1487
1652
|
|
1488
|
-
if m ==
|
1653
|
+
if m == "tmax":
|
1489
1654
|
return target_stats().max
|
1490
1655
|
|
1491
|
-
if m ==
|
1656
|
+
if m == "tpkdb":
|
1492
1657
|
return target_stats().pkdb
|
1493
1658
|
|
1494
|
-
if m ==
|
1659
|
+
if m == "tlrms":
|
1495
1660
|
return target_stats().lrms
|
1496
1661
|
|
1497
|
-
if m ==
|
1662
|
+
if m == "tpkr":
|
1498
1663
|
return target_stats().pkr
|
1499
1664
|
|
1500
|
-
if m ==
|
1665
|
+
if m == "ttr":
|
1501
1666
|
return target_stats().tr
|
1502
1667
|
|
1503
|
-
if m ==
|
1668
|
+
if m == "tcr":
|
1504
1669
|
return target_stats().cr
|
1505
1670
|
|
1506
|
-
if m ==
|
1671
|
+
if m == "tfl":
|
1507
1672
|
return target_stats().fl
|
1508
1673
|
|
1509
|
-
if m ==
|
1674
|
+
if m == "tpkc":
|
1510
1675
|
return target_stats().pkc
|
1511
1676
|
|
1512
|
-
if m.startswith(
|
1677
|
+
if m.startswith("tasr"):
|
1513
1678
|
return target_asr(get_asr_name(m))
|
1514
1679
|
|
1515
|
-
if m ==
|
1680
|
+
if m == "ndco":
|
1516
1681
|
return noise_stats().dco
|
1517
1682
|
|
1518
|
-
if m ==
|
1683
|
+
if m == "nmin":
|
1519
1684
|
return noise_stats().min
|
1520
1685
|
|
1521
|
-
if m ==
|
1686
|
+
if m == "nmax":
|
1522
1687
|
return noise_stats().max
|
1523
1688
|
|
1524
|
-
if m ==
|
1689
|
+
if m == "npkdb":
|
1525
1690
|
return noise_stats().pkdb
|
1526
1691
|
|
1527
|
-
if m ==
|
1692
|
+
if m == "nlrms":
|
1528
1693
|
return noise_stats().lrms
|
1529
1694
|
|
1530
|
-
if m ==
|
1695
|
+
if m == "npkr":
|
1531
1696
|
return noise_stats().pkr
|
1532
1697
|
|
1533
|
-
if m ==
|
1698
|
+
if m == "ntr":
|
1534
1699
|
return noise_stats().tr
|
1535
1700
|
|
1536
|
-
if m ==
|
1701
|
+
if m == "ncr":
|
1537
1702
|
return noise_stats().cr
|
1538
1703
|
|
1539
|
-
if m ==
|
1704
|
+
if m == "nfl":
|
1540
1705
|
return noise_stats().fl
|
1541
1706
|
|
1542
|
-
if m ==
|
1707
|
+
if m == "npkc":
|
1543
1708
|
return noise_stats().pkc
|
1544
1709
|
|
1545
|
-
if m ==
|
1710
|
+
if m == "sedavg":
|
1546
1711
|
return 0
|
1547
1712
|
|
1548
|
-
if m ==
|
1713
|
+
if m == "sedcnt":
|
1549
1714
|
return 0
|
1550
1715
|
|
1551
|
-
if m ==
|
1716
|
+
if m == "sedtop3":
|
1552
1717
|
return np.zeros(3, dtype=np.float32)
|
1553
1718
|
|
1554
|
-
if m ==
|
1719
|
+
if m == "sedtopn":
|
1555
1720
|
return 0
|
1556
1721
|
|
1557
|
-
if m ==
|
1722
|
+
if m == "ssnr":
|
1558
1723
|
return segsnr_f()
|
1559
1724
|
|
1560
|
-
raise
|
1725
|
+
raise AttributeError(f"Unrecognized metric: '{m}'")
|
1561
1726
|
|
1562
|
-
result: list[float | int | str | Segsnr] = []
|
1727
|
+
result: list[float | int | str | Segsnr | None] = []
|
1563
1728
|
for metric in metrics:
|
1564
1729
|
result.append(calc(metric))
|
1565
1730
|
|
@@ -1577,13 +1742,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1577
1742
|
from .db_datatypes import SpectralMaskRecord
|
1578
1743
|
|
1579
1744
|
with db() as c:
|
1580
|
-
spectral_mask = SpectralMaskRecord(
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1745
|
+
spectral_mask = SpectralMaskRecord(
|
1746
|
+
*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id", (sm_id,)).fetchone()
|
1747
|
+
)
|
1748
|
+
return SpectralMask(
|
1749
|
+
f_max_width=spectral_mask.f_max_width,
|
1750
|
+
f_num=spectral_mask.f_num,
|
1751
|
+
t_max_width=spectral_mask.t_max_width,
|
1752
|
+
t_num=spectral_mask.t_num,
|
1753
|
+
t_max_percent=spectral_mask.t_max_percent,
|
1754
|
+
)
|
1587
1755
|
|
1588
1756
|
|
1589
1757
|
@lru_cache
|
@@ -1596,30 +1764,21 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1596
1764
|
"""
|
1597
1765
|
import json
|
1598
1766
|
|
1599
|
-
from .datatypes import TruthSetting
|
1600
|
-
from .datatypes import TruthSettings
|
1601
1767
|
from .db_datatypes import TargetFileRecord
|
1602
1768
|
|
1603
1769
|
with db() as c:
|
1604
1770
|
target_file = TargetFileRecord(
|
1605
|
-
*c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
function=entry.get('function', None),
|
1617
|
-
index=entry.get('index', None)))
|
1618
|
-
return TargetFile(name=target_file.name,
|
1619
|
-
samples=target_file.samples,
|
1620
|
-
level_type=target_file.level_type,
|
1621
|
-
truth_settings=truth_settings,
|
1622
|
-
speaker_id=target_file.speaker_id)
|
1771
|
+
*c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
|
1772
|
+
)
|
1773
|
+
|
1774
|
+
return TargetFile(
|
1775
|
+
name=target_file.name,
|
1776
|
+
samples=target_file.samples,
|
1777
|
+
class_indices=json.loads(target_file.class_indices),
|
1778
|
+
level_type=target_file.level_type,
|
1779
|
+
truth_configs=_target_truth_configs(db, t_id),
|
1780
|
+
speaker_id=target_file.speaker_id,
|
1781
|
+
)
|
1623
1782
|
|
1624
1783
|
|
1625
1784
|
@lru_cache
|
@@ -1631,8 +1790,10 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
1631
1790
|
:return: Noise file
|
1632
1791
|
"""
|
1633
1792
|
with db() as c:
|
1634
|
-
noise = c.execute(
|
1635
|
-
|
1793
|
+
noise = c.execute(
|
1794
|
+
"SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
|
1795
|
+
(n_id,),
|
1796
|
+
).fetchone()
|
1636
1797
|
return NoiseFile(name=noise[0], samples=noise[1])
|
1637
1798
|
|
1638
1799
|
|
@@ -1645,9 +1806,12 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
1645
1806
|
:return: Noise
|
1646
1807
|
"""
|
1647
1808
|
with db() as c:
|
1648
|
-
return str(
|
1649
|
-
|
1650
|
-
|
1809
|
+
return str(
|
1810
|
+
c.execute(
|
1811
|
+
"SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
|
1812
|
+
(ir_id + 1,),
|
1813
|
+
).fetchone()[0]
|
1814
|
+
)
|
1651
1815
|
|
1652
1816
|
|
1653
1817
|
@lru_cache
|
@@ -1658,31 +1822,59 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
1658
1822
|
:param m_id: Zero-based mixture ID
|
1659
1823
|
:return: Mixture record
|
1660
1824
|
"""
|
1661
|
-
from .helpers import to_mixture
|
1662
|
-
from .helpers import to_target
|
1663
1825
|
from .db_datatypes import MixtureRecord
|
1664
1826
|
from .db_datatypes import TargetRecord
|
1827
|
+
from .helpers import to_mixture
|
1828
|
+
from .helpers import to_target
|
1665
1829
|
|
1666
1830
|
with db() as c:
|
1667
1831
|
mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
|
1668
|
-
targets = [
|
1669
|
-
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1832
|
+
targets = [
|
1833
|
+
to_target(TargetRecord(*target))
|
1834
|
+
for target in c.execute(
|
1835
|
+
"SELECT target.* "
|
1836
|
+
+ "FROM target, mixture_target "
|
1837
|
+
+ "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
1838
|
+
(mixture.id,),
|
1839
|
+
).fetchall()
|
1840
|
+
]
|
1673
1841
|
|
1674
1842
|
return to_mixture(mixture, targets)
|
1675
1843
|
|
1676
1844
|
|
1677
1845
|
@lru_cache
|
1678
|
-
def _speaker(db: partial, s_id: int | None, tier: str) ->
|
1846
|
+
def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
1679
1847
|
if s_id is None:
|
1680
1848
|
return None
|
1681
1849
|
|
1682
1850
|
with db() as c:
|
1683
|
-
data = c.execute(f
|
1851
|
+
data = c.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
|
1684
1852
|
if data is None:
|
1685
1853
|
return None
|
1686
1854
|
if data[0] is None:
|
1687
1855
|
return None
|
1688
1856
|
return data[0]
|
1857
|
+
|
1858
|
+
|
1859
|
+
@lru_cache
|
1860
|
+
def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
1861
|
+
import json
|
1862
|
+
|
1863
|
+
from .datatypes import TruthConfig
|
1864
|
+
|
1865
|
+
truth_configs: TruthConfigs = {}
|
1866
|
+
with db() as c:
|
1867
|
+
for truth_config_record in c.execute(
|
1868
|
+
"SELECT truth_config.config "
|
1869
|
+
+ "FROM truth_config, target_file_truth_config "
|
1870
|
+
+ "WHERE ? = target_file_truth_config.target_file_id "
|
1871
|
+
+ "AND truth_config.id = target_file_truth_config.truth_config_id",
|
1872
|
+
(t_id,),
|
1873
|
+
).fetchall():
|
1874
|
+
truth_config = json.loads(truth_config_record[0])
|
1875
|
+
truth_configs[truth_config["name"]] = TruthConfig(
|
1876
|
+
function=truth_config["function"],
|
1877
|
+
stride_reduction=truth_config["stride_reduction"],
|
1878
|
+
config=truth_config["config"],
|
1879
|
+
)
|
1880
|
+
return truth_configs
|