sonusai 0.17.3__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/calc_metric_spenh.py +74 -45
- sonusai/doc/doc.py +0 -24
- sonusai/genmetrics.py +146 -0
- sonusai/genmixdb.py +0 -2
- sonusai/mixture/__init__.py +0 -1
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +2 -9
- sonusai/mixture/generation.py +136 -38
- sonusai/mixture/helpers.py +58 -1
- sonusai/mixture/mapped_snr_f.py +56 -9
- sonusai/mixture/mixdb.py +293 -169
- sonusai/mixture/tokenized_shell_vars.py +8 -1
- sonusai/speech/textgrid.py +6 -24
- {sonusai-0.17.3.dist-info → sonusai-0.18.0.dist-info}/METADATA +3 -1
- {sonusai-0.17.3.dist-info → sonusai-0.18.0.dist-info}/RECORD +18 -24
- sonusai/mixture/speaker_metadata.py +0 -35
- sonusai/mkmanifest.py +0 -209
- sonusai/utils/asr_manifest_functions/__init__.py +0 -6
- sonusai/utils/asr_manifest_functions/data.py +0 -1
- sonusai/utils/asr_manifest_functions/librispeech.py +0 -46
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +0 -29
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +0 -66
- {sonusai-0.17.3.dist-info → sonusai-0.18.0.dist-info}/WHEEL +0 -0
- {sonusai-0.17.3.dist-info → sonusai-0.18.0.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -1,15 +1,12 @@
 from functools import cached_property
 from functools import lru_cache
 from functools import partial
-from pathlib import Path
 from sqlite3 import Connection
 from sqlite3 import Cursor
 from typing import Any
 from typing import Callable
 from typing import Optional
 
-from praatio import textgrid
-from praatio.utilities.constants import Interval
 from sonusai.mixture.datatypes import AudioF
 from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import AudiosF
@@ -33,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
 from sonusai.mixture.datatypes import TransformConfig
 from sonusai.mixture.datatypes import Truth
 from sonusai.mixture.datatypes import UniversalSNR
-from sonusai.mixture.tokenized_shell_vars import tokenized_expand
 
 
 def db_file(location: str, test: bool = False) -> str:
@@ -87,14 +83,12 @@ class MixtureDatabase:
     def __init__(self, location: str, test: bool = False) -> None:
         self.location = location
         self.db = partial(SQLiteContextManager, self.location, test)
-        self._speaker_metadata_tiers: list[str] = []
 
     @cached_property
     def json(self) -> str:
         from .datatypes import MixtureDatabaseConfig
 
         config = MixtureDatabaseConfig(
-            asr_manifest=self.asr_manifests,
             class_balancing=self.class_balancing,
             class_labels=self.class_labels,
             class_weights_threshold=self.class_weights_thresholds,
@@ -120,86 +114,6 @@ class MixtureDatabase:
         with open(file=json_name, mode='w') as file:
             file.write(self.json)
 
-    def target_asr_data(self, t_id: int) -> str | None:
-        """Get the ASR data for the given target ID
-
-        :param t_id: Target ID
-        :return: ASR text or None
-        """
-        from .tokenized_shell_vars import tokenized_expand
-
-        name, _ = tokenized_expand(self.target_file(t_id).name)
-        return self.asr_manifest_data.get(name, None)
-
-    def mixture_asr_data(self, m_id: int) -> list[str | None]:
-        """Get the ASR data for the given mixid
-
-        :param m_id: Zero-based mixture ID
-        :return: List of ASR text or None
-        """
-        return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
-
-    @cached_property
-    def asr_manifest_data(self) -> dict[str, str]:
-        """Get ASR data
-
-        Each line of a manifest file should be in the following format:
-
-        {"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
-
-        The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
-        text field should contain the full transcript for the utterance, and the duration field should reflect the
-        duration of the utterance in seconds.
-
-        Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
-        on one line. The fields that describe the file should be separated by commas, and have the form
-        "field_name": value, as shown above.
-
-        Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
-        directory as the manifest, or even in any specific directory structure.
-
-        The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
-        text.
-        """
-        import json
-
-        from sonusai import SonusAIError
-        from .tokenized_shell_vars import tokenized_expand
-
-        expected_keys = ['audio_filepath', 'text', 'duration']
-
-        def _error_preamble(e_name: str, e_line_num: int) -> str:
-            return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
-
-        asr_manifest_data: dict[str, str] = {}
-
-        for name in self.asr_manifests:
-            expanded_name, _ = tokenized_expand(name)
-            with open(file=expanded_name, mode='r') as f:
-                line_num = 1
-                for line in f:
-                    result = json.loads(line.strip())
-
-                    for key in expected_keys:
-                        if key not in result:
-                            SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
-
-                    for key in result.keys():
-                        if key not in expected_keys:
-                            SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
-
-                    key, _ = tokenized_expand(result['audio_filepath'])
-                    value = result['text']
-
-                    if key in asr_manifest_data:
-                        SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
-
-                    asr_manifest_data[key] = value
-
-                    line_num += 1
-
-        return asr_manifest_data
-
     @cached_property
     def fg_config(self) -> FeatureGeneratorConfig:
         return FeatureGeneratorConfig(feature_mode=self.feature,
@@ -292,14 +206,14 @@ class MixtureDatabase:
     def feature_step_samples(self) -> int:
         return self.ft_config.R * self.fg_decimation * self.fg_step
 
-    def total_samples(self,
-        return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(
+    def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
+        return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
 
-    def total_transform_frames(self,
-        return self.total_samples(
+    def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
+        return self.total_samples(m_ids) // self.ft_config.R
 
-    def total_feature_frames(self,
-        return self.total_samples(
+    def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
+        return self.total_samples(m_ids) // self.feature_step_samples
 
     def mixture_transform_frames(self, samples: int) -> int:
         return samples // self.ft_config.R
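The three aggregate methods above now take an explicit m_ids argument of type GeneralizedIDs, defaulting to '*' (all mixtures). A minimal usage sketch, assuming a previously generated mixture database at a hypothetical location:

from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/path/to/mixdb')  # hypothetical location of a genmixdb output

# '*' selects every mixture; a list of integer IDs selects a subset
all_samples = mixdb.total_samples('*')
subset_frames = mixdb.total_feature_frames([0, 1, 2])
print(all_samples, subset_frames)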
@@ -307,24 +221,15 @@ class MixtureDatabase:
     def mixture_feature_frames(self, samples: int) -> int:
         return samples // self.feature_step_samples
 
-    def mixids_to_list(self,
+    def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
         """Resolve generalized mixture IDs to a list of integers
 
-        :param
+        :param m_ids: Generalized mixture IDs
         :return: List of mixture ID integers
         """
         from .helpers import generic_ids_to_list
 
-        return generic_ids_to_list(self.num_mixtures,
-
-    @cached_property
-    def asr_manifests(self) -> list[str]:
-        """Get ASR manifests from db
-
-        :return: ASR manifests
-        """
-        with self.db() as c:
-            return [str(item[0]) for item in c.execute("SELECT asr_manifest.manifest FROM asr_manifest").fetchall()]
+        return generic_ids_to_list(self.num_mixtures, m_ids)
 
     @cached_property
     def class_labels(self) -> list[str]:
@@ -407,7 +312,8 @@ class MixtureDatabase:
 
         with self.db() as c:
             target_files: TargetFiles = []
-            for target in c.execute(
+            for target in c.execute(
+                    "SELECT target_file.name, samples, level_type, id, speaker_id FROM target_file").fetchall():
                 truth_settings: TruthSettings = []
                 for ts in c.execute(
                     "SELECT truth_setting.setting " +
@@ -422,7 +328,8 @@ class MixtureDatabase:
                 target_files.append(TargetFile(name=target[0],
                                                samples=target[1],
                                                level_type=target[2],
-                                               truth_settings=truth_settings
+                                               truth_settings=truth_settings,
+                                               speaker_id=target[4]))
             return target_files
 
     @cached_property
@@ -719,7 +626,7 @@ class MixtureDatabase:
 
         :param m_id: Zero-based mixture ID
         :param targets: List of augmented target audio data (one per target in the mixup)
-        :param target: Augmented target audio for the given
+        :param target: Augmented target audio for the given m_id
         :param force: Force computing data from original sources regardless of whether cached data exists
         :return: Augmented target transform data
         """
@@ -1077,97 +984,312 @@ class MixtureDatabase:
         return class_count
 
     @cached_property
-    def
-
+    def speaker_metadata_tiers(self) -> list[str]:
+        import json
 
-
-
-        data: dict[str, dict[str, SpeechMetadata]] = {}
-        for file in self.target_files:
-            data[file.name] = {}
-            file_name, _ = tokenized_expand(file.name)
-            tg_file = Path(file_name).with_suffix('.TextGrid')
-            if tg_file.exists():
-                tg = textgrid.openTextgrid(str(tg_file), includeEmptyIntervals=False)
-                for tier in tg.tierNames:
-                    entries = tg.getTier(tier).entries
-                    if len(entries) > 1:
-                        data[file.name][tier] = entries
-                    else:
-                        data[file.name][tier] = entries[0].label
+        with self.db() as c:
+            return json.loads(c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
 
-
+    @cached_property
+    def textgrid_metadata_tiers(self) -> list[str]:
+        import json
+
+        with self.db() as c:
+            return json.loads(c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
 
     @cached_property
     def speech_metadata_tiers(self) -> list[str]:
-        return sorted(
+        return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
 
-    def
-
-
-
+    def speaker(self, speaker_id: int | None, tier: str) -> Optional[str]:
+        if speaker_id is None:
+            return None
+
+        with self.db() as c:
+            data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (speaker_id,)).fetchone()
+            if data is None:
+                return None
+            if data[0] is None:
+                return None
+            return data[0]
+
+    def speech_metadata(self, tier: str) -> list[str]:
+        from .helpers import get_textgrid_tier_from_target_file
+
+        results: set[str] = set()
+        if tier in self.textgrid_metadata_tiers:
+            for target_file in self.target_files:
+                data = get_textgrid_tier_from_target_file(target_file.name, tier)
+                if data is None:
+                    continue
+                if isinstance(data, list):
+                    for item in data:
+                        results.add(item.label)
+                else:
+                    results.add(data)
+        elif tier in self.speaker_metadata_tiers:
+            for target_file in self.target_files:
+                data = self.speaker(target_file.speaker_id, tier)
+                if data is not None:
+                    results.add(data)
+
+        return sorted(results)
+
+    def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
+        from praatio.utilities.constants import Interval
+
+        from .helpers import get_textgrid_tier_from_target_file
+
+        results: list[SpeechMetadata] = []
+        is_textgrid = tier in self.textgrid_metadata_tiers
+        if is_textgrid:
+            for target in self.mixture(mixid).targets:
+                data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
+                if data is not None:
+                    if isinstance(data, list):
+                        # Check for tempo augmentation and adjust Interval start and end data as needed
+                        entries = []
+                        for entry in data:
+                            if target.augmentation.tempo is not None:
+                                entries.append(Interval(entry.start / target.augmentation.tempo,
+                                                        entry.end / target.augmentation.tempo,
+                                                        entry.label))
+                            else:
+                                entries.append(entry)
+                        results.append(entries)
+                    else:
+                        results.append(data)
+        else:
+            for target in self.mixture(mixid).targets:
+                data = self.speaker(self.target_file(target.file_id).speaker_id, tier)
+                if data is not None:
+                    results.append(data)
+
+        return sorted(results)
 
     def mixids_for_speech_metadata(self,
                                    tier: str,
-                                   value: str,
+                                   value: str | None,
                                    predicate: Callable[[str], bool] = None) -> list[int]:
-        """Get a list of
+        """Get a list of mixture IDs for the given speech metadata tier.
 
-        If 'predicate' is None, then include
-        not None, then ignore 'value' and use the given callable to determine which entries
+        If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
+        If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
+        to include.
 
         Examples:
+        >>> mixdb = MixtureDatabase('/mixdb_location')
 
        >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
-        Get
+        Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
 
        >>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
-        Get
+        Get mixture IDs for mixtures with speakers whose ages are less than 25.
 
        >>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
-        Get
+        Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
        """
+        from .helpers import get_textgrid_tier_from_target_file
+
        if predicate is None:
-            def predicate(x: str) -> bool:
+            def predicate(x: str | None) -> bool:
                return x == value
 
        # First get list of matching target files
-        target_files = [
-
+        target_files: list[str] = []
+        is_textgrid = tier in self.textgrid_metadata_tiers
+        for target_file in self.target_files:
+            if is_textgrid:
+                metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
+            else:
+                metadata = self.speaker(target_file.speaker_id, tier)
 
-
-
-
-
+            if not isinstance(metadata, list) and predicate(metadata):
+                target_files.append(target_file.name)
+
+        # Next get list of mixture IDs that contain those target files
+        m_ids: list[int] = []
+        for m_id in self.mixids_to_list():
+            mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(m_id).targets]
            for mixid_target_file in mixid_target_files:
                if mixid_target_file in target_files:
-
+                    m_ids.append(m_id)
 
-        # Return sorted, unique list of
-        return sorted(list(set(
+        # Return sorted, unique list of mixture IDs
+        return sorted(list(set(m_ids)))
 
-    def
-
-        for target in self.mixture(mixid).targets:
-            data = self._speech_metadata[self.target_file(target.file_id).name].get(tier)
+    def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
+        from .helpers import mixture_all_speech_metadata
 
-
-                results.append(None)
-            elif isinstance(data, list):
-                # Check for tempo augmentation and adjust Interval start and end data as needed
-                entries = []
-                for entry in data:
-                    if target.augmentation.tempo is not None:
-                        entries.append(Interval(entry.start / target.augmentation.tempo,
-                                                entry.end / target.augmentation.tempo,
-                                                entry.label))
-                    else:
-                        entries.append(entry)
+        return mixture_all_speech_metadata(self, self.mixture(m_id))
 
-
-
+    def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
+        """Get metric data for the given mixture ID
+
+        :param m_id: Zero-based mixture ID
+        :param metric: Metric data to retrieve
+        :param force: Force computing data from original sources regardless of whether cached data exists
+        :return: Metric data
+        """
+        from sonusai import SonusAIError
+
+        supported_metrics = (
+            'MXSNR',
+            'MXSSNRAVG',
+            'MXSSNRSTD',
+            'MXSSNRDAVG',
+            'MXSSNRDSTD',
+            'MXPESQ',
+            'MXWSDR',
+            'MXPD',
+            'MXSTOI',
+            'MXCSIG',
+            'MXCBAK',
+            'MXCOVL',
+            'TDCO',
+            'TMIN',
+            'TMAX',
+            'TPKDB',
+            'TLRMS',
+            'TPKR',
+            'TTR',
+            'TCR',
+            'TFL',
+            'TPKC',
+            'NDCO',
+            'NMIN',
+            'NMAX',
+            'NPKDB',
+            'NLRMS',
+            'NPKR',
+            'NTR',
+            'NCR',
+            'NFL',
+            'NPKC',
+            'SEDAVG',
+            'SEDCNT',
+            'SEDTOPN',
+        )
+
+        if not (metric in supported_metrics or metric.startswith('MXWER')):
+            raise ValueError(f'Unsupported metric: {metric}')
+
+        if not force:
+            result = self.read_mixture_data(m_id, metric)
+            if result is not None:
+                return result
+
+        mixture = self.mixture(m_id)
+        if mixture is None:
+            raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
+
+        if metric.startswith('MXWER'):
+            return None
+
+        if metric == 'MXSNR':
+            return self.snrs
+
+        if metric == 'MXSSNRAVG':
+            return None
+
+        if metric == 'MXSSNRSTD':
+            return None
+
+        if metric == 'MXSSNRDAVG':
+            return None
+
+        if metric == 'MXSSNRDSTD':
+            return None
+
+        if metric == 'MXPESQ':
+            return None
+
+        if metric == 'MXWSDR':
+            return None
+
+        if metric == 'MXPD':
+            return None
+
+        if metric == 'MXSTOI':
+            return None
+
+        if metric == 'MXCSIG':
+            return None
+
+        if metric == 'MXCBAK':
+            return None
+
+        if metric == 'MXCOVL':
+            return None
+
+        if metric == 'TDCO':
+            return None
+
+        if metric == 'TMIN':
+            return None
+
+        if metric == 'TMAX':
+            return None
+
+        if metric == 'TPKDB':
+            return None
+
+        if metric == 'TLRMS':
+            return None
+
+        if metric == 'TPKR':
+            return None
+
+        if metric == 'TTR':
+            return None
+
+        if metric == 'TCR':
+            return None
+
+        if metric == 'TFL':
+            return None
+
+        if metric == 'TPKC':
+            return None
+
+        if metric == 'NDCO':
+            return None
+
+        if metric == 'NMIN':
+            return None
+
+        if metric == 'NMAX':
+            return None
+
+        if metric == 'NPKDB':
+            return None
+
+        if metric == 'NLRMS':
+            return None
+
+        if metric == 'NPKR':
+            return None
+
+        if metric == 'NTR':
+            return None
+
+        if metric == 'NCR':
+            return None
+
+        if metric == 'NFL':
+            return None
+
+        if metric == 'NPKC':
+            return None
+
+        if metric == 'SEDAVG':
+            return None
+
+        if metric == 'SEDCNT':
+            return None
 
-
+        if metric == 'SEDTOPN':
+            return None
 
 
 @lru_cache
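The hunk above replaces the in-memory TextGrid scan with database-backed speaker and TextGrid metadata tiers and adds a mixture_metric accessor; as the diff shows, most metric names currently fall through to None unless cached data can be read back. A short usage sketch under the same assumption of a hypothetical database location; the 'age' tier is taken from the docstring examples and may not exist in every database:

from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/path/to/mixdb')  # hypothetical location

# Metadata tiers stored in the speaker table and found in TextGrid files
print(mixdb.speaker_metadata_tiers)
print(mixdb.textgrid_metadata_tiers)

# All distinct values of one tier across the target files
ages = mixdb.speech_metadata('age')

# Mixture IDs selected by a predicate over a tier value
young = mixdb.mixids_for_speech_metadata('age', None, lambda x: x is not None and int(x) < 25)

# Per-mixture metric lookup; MXSNR is one of the supported metric names listed above
snr = mixdb.mixture_metric(0, 'MXSNR')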
@@ -1205,8 +1327,9 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
    from .datatypes import TruthSettings
 
    with db() as c:
-        target = c.execute(
-
+        target = c.execute(
+            "SELECT target_file.name, samples, level_type, speaker_id FROM target_file WHERE ? = target_file.id",
+            (t_id,)).fetchone()
 
        truth_settings: TruthSettings = []
        for ts in c.execute(
@@ -1222,7 +1345,8 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
    return TargetFile(name=target[0],
                      samples=target[1],
                      level_type=target[2],
-                      truth_settings=truth_settings
+                      truth_settings=truth_settings,
+                      speaker_id=target[3])
 
 
 @lru_cache
sonusai/mixture/tokenized_shell_vars.py
CHANGED
@@ -1,4 +1,7 @@
-def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
+from pathlib import Path
+
+
+def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
     """Expand shell variables of the forms $var, ${var} and %var%.
     Unknown variables are left unchanged.
 
@@ -25,6 +28,9 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
     if isinstance(name, bytes):
         name = name.decode('utf-8')
 
+    if isinstance(name, Path):
+        name = name.as_posix()
+
     name = os.fspath(name)
     token_map: dict = {}
 
@@ -121,6 +127,7 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
         else:
             result += c
             index += 1
+
     return result, token_map
 
 
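tokenized_expand now accepts pathlib.Path values in addition to str and bytes. A minimal sketch; the DATA environment variable is only an illustration, and unknown variables are left unchanged:

from pathlib import Path

from sonusai.mixture.tokenized_shell_vars import tokenized_expand

# Path input is converted to a POSIX string before expansion
expanded, token_map = tokenized_expand(Path('$DATA/speech/sample.wav'))
print(expanded)   # expanded path, with $DATA replaced when DATA is set in the environment
print(token_map)  # dictionary of the shell variable tokens that were expanded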
|
sonusai/speech/textgrid.py
CHANGED
@@ -6,37 +6,19 @@ from praatio.utilities.constants import Interval
 from .types import TimeAlignedType
 
 
-def _get_duration(name: str) -> float:
-    from pydub import AudioSegment
-
-    from sonusai import SonusAIError
-
-    try:
-        return AudioSegment.from_file(name).duration_seconds
-    except Exception as e:
-        raise SonusAIError(f'Error reading {name}: {e}')
-
-
 def create_textgrid(prompt: Path,
-                    speaker_id: str,
-                    speaker: dict,
                     output_dir: Path,
                     text: TimeAlignedType = None,
                     words: list[TimeAlignedType] = None,
                     phonemes: list[TimeAlignedType] = None) -> None:
-    if text is
-
-            'text': [text],
-            'words': words})
-    else:
-        min_t = 0
-        max_t = _get_duration(str(prompt))
+    if text is None and words is None and phonemes is None:
+        return
 
-
+    min_t, max_t = _get_min_max({'phonemes': phonemes,
+                                 'text': [text],
+                                 'words': words})
 
-    tg
-    for tier in speaker.keys():
-        tg.addTier(textgrid.IntervalTier(tier, [Interval(min_t, max_t, str(speaker[tier]))], min_t, max_t))
+    tg = textgrid.Textgrid()
 
     if text is not None:
         entries = [Interval(text.start, text.end, text.text)]