sonusai 0.17.3__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/calc_metric_spenh.py +74 -45
- sonusai/doc/doc.py +0 -24
- sonusai/genmetrics.py +146 -0
- sonusai/genmixdb.py +0 -2
- sonusai/mixture/__init__.py +0 -1
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +2 -9
- sonusai/mixture/db_datatypes.py +72 -0
- sonusai/mixture/generation.py +139 -38
- sonusai/mixture/helpers.py +75 -16
- sonusai/mixture/mapped_snr_f.py +56 -9
- sonusai/mixture/mixdb.py +347 -226
- sonusai/mixture/tokenized_shell_vars.py +8 -1
- sonusai/speech/textgrid.py +6 -24
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/METADATA +3 -1
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/RECORD +19 -24
- sonusai/mixture/speaker_metadata.py +0 -35
- sonusai/mkmanifest.py +0 -209
- sonusai/utils/asr_manifest_functions/__init__.py +0 -6
- sonusai/utils/asr_manifest_functions/data.py +0 -1
- sonusai/utils/asr_manifest_functions/librispeech.py +0 -46
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +0 -29
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +0 -66
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/WHEEL +0 -0
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -1,15 +1,12 @@
|
|
1
1
|
from functools import cached_property
|
2
2
|
from functools import lru_cache
|
3
3
|
from functools import partial
|
4
|
-
from pathlib import Path
|
5
4
|
from sqlite3 import Connection
|
6
5
|
from sqlite3 import Cursor
|
7
6
|
from typing import Any
|
8
7
|
from typing import Callable
|
9
8
|
from typing import Optional
|
10
9
|
|
11
|
-
from praatio import textgrid
|
12
|
-
from praatio.utilities.constants import Interval
|
13
10
|
from sonusai.mixture.datatypes import AudioF
|
14
11
|
from sonusai.mixture.datatypes import AudioT
|
15
12
|
from sonusai.mixture.datatypes import AudiosF
|
@@ -33,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
|
|
33
30
|
from sonusai.mixture.datatypes import TransformConfig
|
34
31
|
from sonusai.mixture.datatypes import Truth
|
35
32
|
from sonusai.mixture.datatypes import UniversalSNR
|
36
|
-
from sonusai.mixture.tokenized_shell_vars import tokenized_expand
|
37
33
|
|
38
34
|
|
39
35
|
def db_file(location: str, test: bool = False) -> str:
|
@@ -87,14 +83,12 @@ class MixtureDatabase:
|
|
87
83
|
def __init__(self, location: str, test: bool = False) -> None:
|
88
84
|
self.location = location
|
89
85
|
self.db = partial(SQLiteContextManager, self.location, test)
|
90
|
-
self._speaker_metadata_tiers: list[str] = []
|
91
86
|
|
92
87
|
@cached_property
|
93
88
|
def json(self) -> str:
|
94
89
|
from .datatypes import MixtureDatabaseConfig
|
95
90
|
|
96
91
|
config = MixtureDatabaseConfig(
|
97
|
-
asr_manifest=self.asr_manifests,
|
98
92
|
class_balancing=self.class_balancing,
|
99
93
|
class_labels=self.class_labels,
|
100
94
|
class_weights_threshold=self.class_weights_thresholds,
|
@@ -120,86 +114,6 @@ class MixtureDatabase:
|
|
120
114
|
with open(file=json_name, mode='w') as file:
|
121
115
|
file.write(self.json)
|
122
116
|
|
123
|
-
def target_asr_data(self, t_id: int) -> str | None:
|
124
|
-
"""Get the ASR data for the given target ID
|
125
|
-
|
126
|
-
:param t_id: Target ID
|
127
|
-
:return: ASR text or None
|
128
|
-
"""
|
129
|
-
from .tokenized_shell_vars import tokenized_expand
|
130
|
-
|
131
|
-
name, _ = tokenized_expand(self.target_file(t_id).name)
|
132
|
-
return self.asr_manifest_data.get(name, None)
|
133
|
-
|
134
|
-
def mixture_asr_data(self, m_id: int) -> list[str | None]:
|
135
|
-
"""Get the ASR data for the given mixid
|
136
|
-
|
137
|
-
:param m_id: Zero-based mixture ID
|
138
|
-
:return: List of ASR text or None
|
139
|
-
"""
|
140
|
-
return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
|
141
|
-
|
142
|
-
@cached_property
|
143
|
-
def asr_manifest_data(self) -> dict[str, str]:
|
144
|
-
"""Get ASR data
|
145
|
-
|
146
|
-
Each line of a manifest file should be in the following format:
|
147
|
-
|
148
|
-
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
|
149
|
-
|
150
|
-
The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
|
151
|
-
text field should contain the full transcript for the utterance, and the duration field should reflect the
|
152
|
-
duration of the utterance in seconds.
|
153
|
-
|
154
|
-
Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
|
155
|
-
on one line. The fields that describe the file should be separated by commas, and have the form
|
156
|
-
"field_name": value, as shown above.
|
157
|
-
|
158
|
-
Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
|
159
|
-
directory as the manifest, or even in any specific directory structure.
|
160
|
-
|
161
|
-
The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
|
162
|
-
text.
|
163
|
-
"""
|
164
|
-
import json
|
165
|
-
|
166
|
-
from sonusai import SonusAIError
|
167
|
-
from .tokenized_shell_vars import tokenized_expand
|
168
|
-
|
169
|
-
expected_keys = ['audio_filepath', 'text', 'duration']
|
170
|
-
|
171
|
-
def _error_preamble(e_name: str, e_line_num: int) -> str:
|
172
|
-
return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
|
173
|
-
|
174
|
-
asr_manifest_data: dict[str, str] = {}
|
175
|
-
|
176
|
-
for name in self.asr_manifests:
|
177
|
-
expanded_name, _ = tokenized_expand(name)
|
178
|
-
with open(file=expanded_name, mode='r') as f:
|
179
|
-
line_num = 1
|
180
|
-
for line in f:
|
181
|
-
result = json.loads(line.strip())
|
182
|
-
|
183
|
-
for key in expected_keys:
|
184
|
-
if key not in result:
|
185
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
|
186
|
-
|
187
|
-
for key in result.keys():
|
188
|
-
if key not in expected_keys:
|
189
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
|
190
|
-
|
191
|
-
key, _ = tokenized_expand(result['audio_filepath'])
|
192
|
-
value = result['text']
|
193
|
-
|
194
|
-
if key in asr_manifest_data:
|
195
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
|
196
|
-
|
197
|
-
asr_manifest_data[key] = value
|
198
|
-
|
199
|
-
line_num += 1
|
200
|
-
|
201
|
-
return asr_manifest_data
|
202
|
-
|
203
117
|
@cached_property
|
204
118
|
def fg_config(self) -> FeatureGeneratorConfig:
|
205
119
|
return FeatureGeneratorConfig(feature_mode=self.feature,
|
@@ -215,32 +129,32 @@ class MixtureDatabase:
|
|
215
129
|
@cached_property
|
216
130
|
def num_classes(self) -> int:
|
217
131
|
with self.db() as c:
|
218
|
-
return int(c.execute("SELECT top.num_classes
|
132
|
+
return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
|
219
133
|
|
220
134
|
@cached_property
|
221
135
|
def truth_mutex(self) -> bool:
|
222
136
|
with self.db() as c:
|
223
|
-
return bool(c.execute("SELECT top.truth_mutex
|
137
|
+
return bool(c.execute("SELECT top.truth_mutex FROM top").fetchone()[0])
|
224
138
|
|
225
139
|
@cached_property
|
226
140
|
def truth_reduction_function(self) -> str:
|
227
141
|
with self.db() as c:
|
228
|
-
return str(c.execute("SELECT top.truth_reduction_function
|
142
|
+
return str(c.execute("SELECT top.truth_reduction_function FROM top").fetchone()[0])
|
229
143
|
|
230
144
|
@cached_property
|
231
145
|
def noise_mix_mode(self) -> str:
|
232
146
|
with self.db() as c:
|
233
|
-
return str(c.execute("SELECT top.noise_mix_mode
|
147
|
+
return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
|
234
148
|
|
235
149
|
@cached_property
|
236
150
|
def class_balancing(self) -> bool:
|
237
151
|
with self.db() as c:
|
238
|
-
return bool(c.execute("SELECT top.class_balancing
|
152
|
+
return bool(c.execute("SELECT top.class_balancing FROM top").fetchone()[0])
|
239
153
|
|
240
154
|
@cached_property
|
241
155
|
def feature(self) -> str:
|
242
156
|
with self.db() as c:
|
243
|
-
return str(c.execute("SELECT top.feature
|
157
|
+
return str(c.execute("SELECT top.feature FROM top").fetchone()[0])
|
244
158
|
|
245
159
|
@cached_property
|
246
160
|
def fg_decimation(self) -> int:
|
@@ -292,14 +206,14 @@ class MixtureDatabase:
|
|
292
206
|
def feature_step_samples(self) -> int:
|
293
207
|
return self.ft_config.R * self.fg_decimation * self.fg_step
|
294
208
|
|
295
|
-
def total_samples(self,
|
296
|
-
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(
|
209
|
+
def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
|
210
|
+
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
|
297
211
|
|
298
|
-
def total_transform_frames(self,
|
299
|
-
return self.total_samples(
|
212
|
+
def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
|
213
|
+
return self.total_samples(m_ids) // self.ft_config.R
|
300
214
|
|
301
|
-
def total_feature_frames(self,
|
302
|
-
return self.total_samples(
|
215
|
+
def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
|
216
|
+
return self.total_samples(m_ids) // self.feature_step_samples
|
303
217
|
|
304
218
|
def mixture_transform_frames(self, samples: int) -> int:
|
305
219
|
return samples // self.ft_config.R
|
@@ -307,24 +221,15 @@ class MixtureDatabase:
|
|
307
221
|
def mixture_feature_frames(self, samples: int) -> int:
|
308
222
|
return samples // self.feature_step_samples
|
309
223
|
|
310
|
-
def mixids_to_list(self,
|
224
|
+
def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
|
311
225
|
"""Resolve generalized mixture IDs to a list of integers
|
312
226
|
|
313
|
-
:param
|
227
|
+
:param m_ids: Generalized mixture IDs
|
314
228
|
:return: List of mixture ID integers
|
315
229
|
"""
|
316
230
|
from .helpers import generic_ids_to_list
|
317
231
|
|
318
|
-
return generic_ids_to_list(self.num_mixtures,
|
319
|
-
|
320
|
-
@cached_property
|
321
|
-
def asr_manifests(self) -> list[str]:
|
322
|
-
"""Get ASR manifests from db
|
323
|
-
|
324
|
-
:return: ASR manifests
|
325
|
-
"""
|
326
|
-
with self.db() as c:
|
327
|
-
return [str(item[0]) for item in c.execute("SELECT asr_manifest.manifest FROM asr_manifest").fetchall()]
|
232
|
+
return generic_ids_to_list(self.num_mixtures, m_ids)
|
328
233
|
|
329
234
|
@cached_property
|
330
235
|
def class_labels(self) -> list[str]:
|
@@ -377,14 +282,16 @@ class MixtureDatabase:
|
|
377
282
|
|
378
283
|
:return: Spectral masks
|
379
284
|
"""
|
285
|
+
from .db_datatypes import SpectralMaskRecord
|
286
|
+
|
380
287
|
with self.db() as c:
|
381
|
-
|
382
|
-
|
383
|
-
return [SpectralMask(f_max_width=spectral_mask
|
384
|
-
f_num=spectral_mask
|
385
|
-
t_max_width=spectral_mask
|
386
|
-
t_num=spectral_mask
|
387
|
-
t_max_percent=spectral_mask
|
288
|
+
spectral_masks = [SpectralMaskRecord(*result) for result in
|
289
|
+
c.execute("SELECT * FROM spectral_mask").fetchall()]
|
290
|
+
return [SpectralMask(f_max_width=spectral_mask.f_max_width,
|
291
|
+
f_num=spectral_mask.f_num,
|
292
|
+
t_max_width=spectral_mask.t_max_width,
|
293
|
+
t_num=spectral_mask.t_num,
|
294
|
+
t_max_percent=spectral_mask.t_max_percent) for spectral_mask in spectral_masks]
|
388
295
|
|
389
296
|
def spectral_mask(self, sm_id: int) -> SpectralMask:
|
390
297
|
"""Get spectral mask with ID from db
|
@@ -404,25 +311,29 @@ class MixtureDatabase:
|
|
404
311
|
|
405
312
|
from .datatypes import TruthSetting
|
406
313
|
from .datatypes import TruthSettings
|
314
|
+
from .db_datatypes import TargetFileRecord
|
407
315
|
|
408
316
|
with self.db() as c:
|
409
317
|
target_files: TargetFiles = []
|
410
|
-
|
318
|
+
target_file_records = [TargetFileRecord(*result) for result in
|
319
|
+
c.execute("SELECT * FROM target_file").fetchall()]
|
320
|
+
for target_file_record in target_file_records:
|
411
321
|
truth_settings: TruthSettings = []
|
412
|
-
for
|
322
|
+
for truth_setting_records in c.execute(
|
413
323
|
"SELECT truth_setting.setting " +
|
414
324
|
"FROM truth_setting, target_file_truth_setting " +
|
415
325
|
"WHERE ? = target_file_truth_setting.target_file_id " +
|
416
326
|
"AND truth_setting.id = target_file_truth_setting.truth_setting_id",
|
417
|
-
(
|
418
|
-
|
419
|
-
truth_settings.append(TruthSetting(config=
|
420
|
-
function=
|
421
|
-
index=
|
422
|
-
target_files.append(TargetFile(name=
|
423
|
-
samples=
|
424
|
-
level_type=
|
425
|
-
truth_settings=truth_settings
|
327
|
+
(target_file_record.id,)).fetchall():
|
328
|
+
truth_setting = json.loads(truth_setting_records[0])
|
329
|
+
truth_settings.append(TruthSetting(config=truth_setting.get('config', None),
|
330
|
+
function=truth_setting.get('function', None),
|
331
|
+
index=truth_setting.get('index', None)))
|
332
|
+
target_files.append(TargetFile(name=target_file_record.name,
|
333
|
+
samples=target_file_record.samples,
|
334
|
+
level_type=target_file_record.level_type,
|
335
|
+
truth_settings=truth_settings,
|
336
|
+
speaker_id=target_file_record.speaker_id))
|
426
337
|
return target_files
|
427
338
|
|
428
339
|
@cached_property
|
@@ -532,18 +443,16 @@ class MixtureDatabase:
|
|
532
443
|
"""
|
533
444
|
from .helpers import to_mixture
|
534
445
|
from .helpers import to_target
|
446
|
+
from .db_datatypes import MixtureRecord
|
447
|
+
from .db_datatypes import TargetRecord
|
535
448
|
|
536
449
|
with self.db() as c:
|
537
450
|
mixtures: Mixtures = []
|
538
|
-
for mixture in c.execute(
|
539
|
-
|
540
|
-
"
|
541
|
-
"FROM mixture").fetchall():
|
542
|
-
targets = [to_target(target) for target in c.execute(
|
543
|
-
"SELECT target.file_id, augmentation, gain " +
|
544
|
-
"FROM target, mixture_target " +
|
451
|
+
for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
|
452
|
+
targets = [to_target(TargetRecord(*target)) for target in c.execute(
|
453
|
+
"SELECT target.* FROM target, mixture_target " +
|
545
454
|
"WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
546
|
-
(mixture
|
455
|
+
(mixture.id,)).fetchall()]
|
547
456
|
mixtures.append(to_mixture(mixture, targets))
|
548
457
|
return mixtures
|
549
458
|
|
@@ -567,7 +476,7 @@ class MixtureDatabase:
|
|
567
476
|
@cached_property
|
568
477
|
def mixid_width(self) -> int:
|
569
478
|
with self.db() as c:
|
570
|
-
return int(c.execute("SELECT top.mixid_width
|
479
|
+
return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
|
571
480
|
|
572
481
|
def location_filename(self, name: str) -> str:
|
573
482
|
"""Add the location to the given file name
|
@@ -719,7 +628,7 @@ class MixtureDatabase:
|
|
719
628
|
|
720
629
|
:param m_id: Zero-based mixture ID
|
721
630
|
:param targets: List of augmented target audio data (one per target in the mixup)
|
722
|
-
:param target: Augmented target audio for the given
|
631
|
+
:param target: Augmented target audio for the given m_id
|
723
632
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
724
633
|
:return: Augmented target transform data
|
725
634
|
"""
|
@@ -1077,97 +986,298 @@ class MixtureDatabase:
|
|
1077
986
|
return class_count
|
1078
987
|
|
1079
988
|
@cached_property
|
1080
|
-
def
|
1081
|
-
|
989
|
+
def speaker_metadata_tiers(self) -> list[str]:
|
990
|
+
import json
|
1082
991
|
|
1083
|
-
|
1084
|
-
|
1085
|
-
data: dict[str, dict[str, SpeechMetadata]] = {}
|
1086
|
-
for file in self.target_files:
|
1087
|
-
data[file.name] = {}
|
1088
|
-
file_name, _ = tokenized_expand(file.name)
|
1089
|
-
tg_file = Path(file_name).with_suffix('.TextGrid')
|
1090
|
-
if tg_file.exists():
|
1091
|
-
tg = textgrid.openTextgrid(str(tg_file), includeEmptyIntervals=False)
|
1092
|
-
for tier in tg.tierNames:
|
1093
|
-
entries = tg.getTier(tier).entries
|
1094
|
-
if len(entries) > 1:
|
1095
|
-
data[file.name][tier] = entries
|
1096
|
-
else:
|
1097
|
-
data[file.name][tier] = entries[0].label
|
992
|
+
with self.db() as c:
|
993
|
+
return json.loads(c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
|
1098
994
|
|
1099
|
-
|
995
|
+
@cached_property
|
996
|
+
def textgrid_metadata_tiers(self) -> list[str]:
|
997
|
+
import json
|
998
|
+
|
999
|
+
with self.db() as c:
|
1000
|
+
return json.loads(c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
|
1100
1001
|
|
1101
1002
|
@cached_property
|
1102
1003
|
def speech_metadata_tiers(self) -> list[str]:
|
1103
|
-
return sorted(
|
1004
|
+
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1005
|
+
|
1006
|
+
def speaker(self, s_id: int | None, tier: str) -> Optional[str]:
|
1007
|
+
return _speaker(self.db, s_id, tier)
|
1008
|
+
|
1009
|
+
def speech_metadata(self, tier: str) -> list[str]:
|
1010
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1011
|
+
|
1012
|
+
results: set[str] = set()
|
1013
|
+
if tier in self.textgrid_metadata_tiers:
|
1014
|
+
for target_file in self.target_files:
|
1015
|
+
data = get_textgrid_tier_from_target_file(target_file.name, tier)
|
1016
|
+
if data is None:
|
1017
|
+
continue
|
1018
|
+
if isinstance(data, list):
|
1019
|
+
for item in data:
|
1020
|
+
results.add(item.label)
|
1021
|
+
else:
|
1022
|
+
results.add(data)
|
1023
|
+
elif tier in self.speaker_metadata_tiers:
|
1024
|
+
for target_file in self.target_files:
|
1025
|
+
data = self.speaker(target_file.speaker_id, tier)
|
1026
|
+
if data is not None:
|
1027
|
+
results.add(data)
|
1028
|
+
|
1029
|
+
return sorted(results)
|
1030
|
+
|
1031
|
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
|
1032
|
+
from praatio.utilities.constants import Interval
|
1033
|
+
|
1034
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1035
|
+
|
1036
|
+
results: list[SpeechMetadata] = []
|
1037
|
+
is_textgrid = tier in self.textgrid_metadata_tiers
|
1038
|
+
if is_textgrid:
|
1039
|
+
for target in self.mixture(mixid).targets:
|
1040
|
+
data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
|
1041
|
+
if data is not None:
|
1042
|
+
if isinstance(data, list):
|
1043
|
+
# Check for tempo augmentation and adjust Interval start and end data as needed
|
1044
|
+
entries = []
|
1045
|
+
for entry in data:
|
1046
|
+
if target.augmentation.tempo is not None:
|
1047
|
+
entries.append(Interval(entry.start / target.augmentation.tempo,
|
1048
|
+
entry.end / target.augmentation.tempo,
|
1049
|
+
entry.label))
|
1050
|
+
else:
|
1051
|
+
entries.append(entry)
|
1052
|
+
results.append(entries)
|
1053
|
+
else:
|
1054
|
+
results.append(data)
|
1055
|
+
else:
|
1056
|
+
for target in self.mixture(mixid).targets:
|
1057
|
+
data = self.speaker(self.target_file(target.file_id).speaker_id, tier)
|
1058
|
+
if data is not None:
|
1059
|
+
results.append(data)
|
1104
1060
|
|
1105
|
-
|
1106
|
-
results = sorted(
|
1107
|
-
set([value.get(tier) for value in self._speech_metadata.values() if isinstance(value.get(tier), str)]))
|
1108
|
-
return results
|
1061
|
+
return sorted(results)
|
1109
1062
|
|
1110
1063
|
def mixids_for_speech_metadata(self,
|
1111
1064
|
tier: str,
|
1112
|
-
value: str,
|
1065
|
+
value: str | None,
|
1113
1066
|
predicate: Callable[[str], bool] = None) -> list[int]:
|
1114
|
-
"""Get a list of
|
1067
|
+
"""Get a list of mixture IDs for the given speech metadata tier.
|
1115
1068
|
|
1116
|
-
If 'predicate' is None, then include
|
1117
|
-
not None, then ignore 'value' and use the given callable to determine which entries
|
1069
|
+
If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
1070
|
+
If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
|
1071
|
+
to include.
|
1118
1072
|
|
1119
1073
|
Examples:
|
1074
|
+
>>> mixdb = MixtureDatabase('/mixdb_location')
|
1120
1075
|
|
1121
1076
|
>>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
|
1122
|
-
Get
|
1077
|
+
Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
|
1123
1078
|
|
1124
1079
|
>>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
|
1125
|
-
Get
|
1080
|
+
Get mixture IDs for mixtures with speakers whose ages are less than 25.
|
1126
1081
|
|
1127
1082
|
>>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
|
1128
|
-
Get
|
1083
|
+
Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
|
1129
1084
|
"""
|
1085
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1086
|
+
|
1130
1087
|
if predicate is None:
|
1131
|
-
def predicate(x: str) -> bool:
|
1088
|
+
def predicate(x: str | None) -> bool:
|
1132
1089
|
return x == value
|
1133
1090
|
|
1134
1091
|
# First get list of matching target files
|
1135
|
-
|
1136
|
-
|
1092
|
+
target_file_ids: list[int] = []
|
1093
|
+
is_textgrid = tier in self.textgrid_metadata_tiers
|
1094
|
+
for target_file_id, target_file in enumerate(self.target_files):
|
1095
|
+
if is_textgrid:
|
1096
|
+
metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
|
1097
|
+
else:
|
1098
|
+
metadata = self.speaker(target_file.speaker_id, tier)
|
1137
1099
|
|
1138
|
-
|
1139
|
-
|
1140
|
-
for mixid in self.mixids_to_list():
|
1141
|
-
mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(mixid).targets]
|
1142
|
-
for mixid_target_file in mixid_target_files:
|
1143
|
-
if mixid_target_file in target_files:
|
1144
|
-
mixids.append(mixid)
|
1100
|
+
if not isinstance(metadata, list) and predicate(metadata):
|
1101
|
+
target_file_ids.append(target_file_id + 1)
|
1145
1102
|
|
1146
|
-
#
|
1147
|
-
|
1103
|
+
# Next get list of mixture IDs that contain those target files
|
1104
|
+
with self.db() as c:
|
1105
|
+
m_ids = c.execute("SELECT mixture_id FROM mixture_target " +
|
1106
|
+
f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
|
1107
|
+
return [x[0] - 1 for x in m_ids]
|
1148
1108
|
|
1149
|
-
def
|
1150
|
-
|
1151
|
-
for target in self.mixture(mixid).targets:
|
1152
|
-
data = self._speech_metadata[self.target_file(target.file_id).name].get(tier)
|
1153
|
-
|
1154
|
-
if data is None:
|
1155
|
-
results.append(None)
|
1156
|
-
elif isinstance(data, list):
|
1157
|
-
# Check for tempo augmentation and adjust Interval start and end data as needed
|
1158
|
-
entries = []
|
1159
|
-
for entry in data:
|
1160
|
-
if target.augmentation.tempo is not None:
|
1161
|
-
entries.append(Interval(entry.start / target.augmentation.tempo,
|
1162
|
-
entry.end / target.augmentation.tempo,
|
1163
|
-
entry.label))
|
1164
|
-
else:
|
1165
|
-
entries.append(entry)
|
1109
|
+
def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
|
1110
|
+
from .helpers import mixture_all_speech_metadata
|
1166
1111
|
|
1167
|
-
|
1168
|
-
results.append(data)
|
1112
|
+
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
1169
1113
|
|
1170
|
-
|
1114
|
+
def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
|
1115
|
+
"""Get metric data for the given mixture ID
|
1116
|
+
|
1117
|
+
:param m_id: Zero-based mixture ID
|
1118
|
+
:param metric: Metric data to retrieve
|
1119
|
+
:param force: Force computing data from original sources regardless of whether cached data exists
|
1120
|
+
:return: Metric data
|
1121
|
+
"""
|
1122
|
+
from sonusai import SonusAIError
|
1123
|
+
|
1124
|
+
supported_metrics = (
|
1125
|
+
'MXSNR',
|
1126
|
+
'MXSSNRAVG',
|
1127
|
+
'MXSSNRSTD',
|
1128
|
+
'MXSSNRDAVG',
|
1129
|
+
'MXSSNRDSTD',
|
1130
|
+
'MXPESQ',
|
1131
|
+
'MXWSDR',
|
1132
|
+
'MXPD',
|
1133
|
+
'MXSTOI',
|
1134
|
+
'MXCSIG',
|
1135
|
+
'MXCBAK',
|
1136
|
+
'MXCOVL',
|
1137
|
+
'TDCO',
|
1138
|
+
'TMIN',
|
1139
|
+
'TMAX',
|
1140
|
+
'TPKDB',
|
1141
|
+
'TLRMS',
|
1142
|
+
'TPKR',
|
1143
|
+
'TTR',
|
1144
|
+
'TCR',
|
1145
|
+
'TFL',
|
1146
|
+
'TPKC',
|
1147
|
+
'NDCO',
|
1148
|
+
'NMIN',
|
1149
|
+
'NMAX',
|
1150
|
+
'NPKDB',
|
1151
|
+
'NLRMS',
|
1152
|
+
'NPKR',
|
1153
|
+
'NTR',
|
1154
|
+
'NCR',
|
1155
|
+
'NFL',
|
1156
|
+
'NPKC',
|
1157
|
+
'SEDAVG',
|
1158
|
+
'SEDCNT',
|
1159
|
+
'SEDTOPN',
|
1160
|
+
)
|
1161
|
+
|
1162
|
+
if not (metric in supported_metrics or metric.startswith('MXWER')):
|
1163
|
+
raise ValueError(f'Unsupported metric: {metric}')
|
1164
|
+
|
1165
|
+
if not force:
|
1166
|
+
result = self.read_mixture_data(m_id, metric)
|
1167
|
+
if result is not None:
|
1168
|
+
return result
|
1169
|
+
|
1170
|
+
mixture = self.mixture(m_id)
|
1171
|
+
if mixture is None:
|
1172
|
+
raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
|
1173
|
+
|
1174
|
+
if metric.startswith('MXWER'):
|
1175
|
+
return None
|
1176
|
+
|
1177
|
+
if metric == 'MXSNR':
|
1178
|
+
return self.snrs
|
1179
|
+
|
1180
|
+
if metric == 'MXSSNRAVG':
|
1181
|
+
return None
|
1182
|
+
|
1183
|
+
if metric == 'MXSSNRSTD':
|
1184
|
+
return None
|
1185
|
+
|
1186
|
+
if metric == 'MXSSNRDAVG':
|
1187
|
+
return None
|
1188
|
+
|
1189
|
+
if metric == 'MXSSNRDSTD':
|
1190
|
+
return None
|
1191
|
+
|
1192
|
+
if metric == 'MXPESQ':
|
1193
|
+
return None
|
1194
|
+
|
1195
|
+
if metric == 'MXWSDR':
|
1196
|
+
return None
|
1197
|
+
|
1198
|
+
if metric == 'MXPD':
|
1199
|
+
return None
|
1200
|
+
|
1201
|
+
if metric == 'MXSTOI':
|
1202
|
+
return None
|
1203
|
+
|
1204
|
+
if metric == 'MXCSIG':
|
1205
|
+
return None
|
1206
|
+
|
1207
|
+
if metric == 'MXCBAK':
|
1208
|
+
return None
|
1209
|
+
|
1210
|
+
if metric == 'MXCOVL':
|
1211
|
+
return None
|
1212
|
+
|
1213
|
+
if metric == 'TDCO':
|
1214
|
+
return None
|
1215
|
+
|
1216
|
+
if metric == 'TMIN':
|
1217
|
+
return None
|
1218
|
+
|
1219
|
+
if metric == 'TMAX':
|
1220
|
+
return None
|
1221
|
+
|
1222
|
+
if metric == 'TPKDB':
|
1223
|
+
return None
|
1224
|
+
|
1225
|
+
if metric == 'TLRMS':
|
1226
|
+
return None
|
1227
|
+
|
1228
|
+
if metric == 'TPKR':
|
1229
|
+
return None
|
1230
|
+
|
1231
|
+
if metric == 'TTR':
|
1232
|
+
return None
|
1233
|
+
|
1234
|
+
if metric == 'TCR':
|
1235
|
+
return None
|
1236
|
+
|
1237
|
+
if metric == 'TFL':
|
1238
|
+
return None
|
1239
|
+
|
1240
|
+
if metric == 'TPKC':
|
1241
|
+
return None
|
1242
|
+
|
1243
|
+
if metric == 'NDCO':
|
1244
|
+
return None
|
1245
|
+
|
1246
|
+
if metric == 'NMIN':
|
1247
|
+
return None
|
1248
|
+
|
1249
|
+
if metric == 'NMAX':
|
1250
|
+
return None
|
1251
|
+
|
1252
|
+
if metric == 'NPKDB':
|
1253
|
+
return None
|
1254
|
+
|
1255
|
+
if metric == 'NLRMS':
|
1256
|
+
return None
|
1257
|
+
|
1258
|
+
if metric == 'NPKR':
|
1259
|
+
return None
|
1260
|
+
|
1261
|
+
if metric == 'NTR':
|
1262
|
+
return None
|
1263
|
+
|
1264
|
+
if metric == 'NCR':
|
1265
|
+
return None
|
1266
|
+
|
1267
|
+
if metric == 'NFL':
|
1268
|
+
return None
|
1269
|
+
|
1270
|
+
if metric == 'NPKC':
|
1271
|
+
return None
|
1272
|
+
|
1273
|
+
if metric == 'SEDAVG':
|
1274
|
+
return None
|
1275
|
+
|
1276
|
+
if metric == 'SEDCNT':
|
1277
|
+
return None
|
1278
|
+
|
1279
|
+
if metric == 'SEDTOPN':
|
1280
|
+
return None
|
1171
1281
|
|
1172
1282
|
|
1173
1283
|
@lru_cache
|
@@ -1178,17 +1288,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1178
1288
|
:param sm_id: Spectral mask ID
|
1179
1289
|
:return: Spectral mask
|
1180
1290
|
"""
|
1291
|
+
from .db_datatypes import SpectralMaskRecord
|
1292
|
+
|
1181
1293
|
with db() as c:
|
1182
|
-
spectral_mask = c.execute(
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
t_max_width=spectral_mask[2],
|
1190
|
-
t_num=spectral_mask[3],
|
1191
|
-
t_max_percent=spectral_mask[4])
|
1294
|
+
spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id",
|
1295
|
+
(sm_id,)).fetchone())
|
1296
|
+
return SpectralMask(f_max_width=spectral_mask.f_max_width,
|
1297
|
+
f_num=spectral_mask.f_num,
|
1298
|
+
t_max_width=spectral_mask.t_max_width,
|
1299
|
+
t_num=spectral_mask.t_num,
|
1300
|
+
t_max_percent=spectral_mask.t_max_percent)
|
1192
1301
|
|
1193
1302
|
|
1194
1303
|
@lru_cache
|
@@ -1203,10 +1312,11 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1203
1312
|
|
1204
1313
|
from .datatypes import TruthSetting
|
1205
1314
|
from .datatypes import TruthSettings
|
1315
|
+
from .db_datatypes import TargetFileRecord
|
1206
1316
|
|
1207
1317
|
with db() as c:
|
1208
|
-
|
1209
|
-
|
1318
|
+
target_file = TargetFileRecord(
|
1319
|
+
*c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone())
|
1210
1320
|
|
1211
1321
|
truth_settings: TruthSettings = []
|
1212
1322
|
for ts in c.execute(
|
@@ -1219,10 +1329,11 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1219
1329
|
truth_settings.append(TruthSetting(config=entry.get('config', None),
|
1220
1330
|
function=entry.get('function', None),
|
1221
1331
|
index=entry.get('index', None)))
|
1222
|
-
return TargetFile(name=
|
1223
|
-
samples=
|
1224
|
-
level_type=
|
1225
|
-
truth_settings=truth_settings
|
1332
|
+
return TargetFile(name=target_file.name,
|
1333
|
+
samples=target_file.samples,
|
1334
|
+
level_type=target_file.level_type,
|
1335
|
+
truth_settings=truth_settings,
|
1336
|
+
speaker_id=target_file.speaker_id)
|
1226
1337
|
|
1227
1338
|
|
1228
1339
|
@lru_cache
|
@@ -1263,19 +1374,29 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
1263
1374
|
"""
|
1264
1375
|
from .helpers import to_mixture
|
1265
1376
|
from .helpers import to_target
|
1377
|
+
from .db_datatypes import MixtureRecord
|
1378
|
+
from .db_datatypes import TargetRecord
|
1266
1379
|
|
1267
1380
|
with db() as c:
|
1268
|
-
mixture = c.execute(
|
1269
|
-
|
1270
|
-
"
|
1271
|
-
"FROM mixture " +
|
1272
|
-
"WHERE ? = mixture.id",
|
1273
|
-
(m_id + 1,)).fetchone()
|
1274
|
-
|
1275
|
-
targets = [to_target(target) for target in c.execute(
|
1276
|
-
"SELECT target.file_id, augmentation, gain " +
|
1381
|
+
mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
|
1382
|
+
targets = [to_target(TargetRecord(*target)) for target in c.execute(
|
1383
|
+
"SELECT target.* " +
|
1277
1384
|
"FROM target, mixture_target " +
|
1278
1385
|
"WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
|
1279
|
-
(mixture
|
1386
|
+
(mixture.id,)).fetchall()]
|
1280
1387
|
|
1281
1388
|
return to_mixture(mixture, targets)
|
1389
|
+
|
1390
|
+
|
1391
|
+
@lru_cache
|
1392
|
+
def _speaker(db: partial, s_id: int | None, tier: str) -> Optional[str]:
|
1393
|
+
if s_id is None:
|
1394
|
+
return None
|
1395
|
+
|
1396
|
+
with db() as c:
|
1397
|
+
data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (s_id,)).fetchone()
|
1398
|
+
if data is None:
|
1399
|
+
return None
|
1400
|
+
if data[0] is None:
|
1401
|
+
return None
|
1402
|
+
return data[0]
|