sonusai 0.17.2__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/audiofe.py +3 -3
- sonusai/calc_metric_spenh.py +81 -52
- sonusai/doc/doc.py +0 -24
- sonusai/genmetrics.py +146 -0
- sonusai/genmixdb.py +0 -2
- sonusai/mixture/__init__.py +0 -1
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +2 -9
- sonusai/mixture/generation.py +136 -38
- sonusai/mixture/helpers.py +58 -1
- sonusai/mixture/mapped_snr_f.py +56 -9
- sonusai/mixture/mixdb.py +293 -170
- sonusai/mixture/sox_augmentation.py +3 -0
- sonusai/mixture/tokenized_shell_vars.py +8 -1
- sonusai/mkwav.py +4 -4
- sonusai/onnx_predict.py +2 -2
- sonusai/post_spenh_targetf.py +2 -2
- sonusai/speech/textgrid.py +6 -24
- sonusai/speech/{voxceleb2.py → voxceleb.py} +19 -3
- sonusai/utils/__init__.py +1 -1
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/{wave.py → write_audio.py} +2 -2
- {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/METADATA +4 -1
- {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/RECORD +27 -33
- sonusai/mixture/speaker_metadata.py +0 -35
- sonusai/mkmanifest.py +0 -209
- sonusai/utils/asr_manifest_functions/__init__.py +0 -6
- sonusai/utils/asr_manifest_functions/data.py +0 -1
- sonusai/utils/asr_manifest_functions/librispeech.py +0 -46
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +0 -29
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +0 -66
- {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/WHEEL +0 -0
- {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -1,16 +1,12 @@
|
|
1
1
|
from functools import cached_property
|
2
2
|
from functools import lru_cache
|
3
3
|
from functools import partial
|
4
|
-
from pathlib import Path
|
5
4
|
from sqlite3 import Connection
|
6
5
|
from sqlite3 import Cursor
|
7
6
|
from typing import Any
|
8
7
|
from typing import Callable
|
9
8
|
from typing import Optional
|
10
9
|
|
11
|
-
from praatio import textgrid
|
12
|
-
from praatio.utilities.constants import Interval
|
13
|
-
|
14
10
|
from sonusai.mixture.datatypes import AudioF
|
15
11
|
from sonusai.mixture.datatypes import AudioT
|
16
12
|
from sonusai.mixture.datatypes import AudiosF
|
@@ -34,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
|
|
34
30
|
from sonusai.mixture.datatypes import TransformConfig
|
35
31
|
from sonusai.mixture.datatypes import Truth
|
36
32
|
from sonusai.mixture.datatypes import UniversalSNR
|
37
|
-
from sonusai.mixture.tokenized_shell_vars import tokenized_expand
|
38
33
|
|
39
34
|
|
40
35
|
def db_file(location: str, test: bool = False) -> str:
|
@@ -88,14 +83,12 @@ class MixtureDatabase:
|
|
88
83
|
def __init__(self, location: str, test: bool = False) -> None:
|
89
84
|
self.location = location
|
90
85
|
self.db = partial(SQLiteContextManager, self.location, test)
|
91
|
-
self._speaker_metadata_tiers: list[str] = []
|
92
86
|
|
93
87
|
@cached_property
|
94
88
|
def json(self) -> str:
|
95
89
|
from .datatypes import MixtureDatabaseConfig
|
96
90
|
|
97
91
|
config = MixtureDatabaseConfig(
|
98
|
-
asr_manifest=self.asr_manifests,
|
99
92
|
class_balancing=self.class_balancing,
|
100
93
|
class_labels=self.class_labels,
|
101
94
|
class_weights_threshold=self.class_weights_thresholds,
|
@@ -121,86 +114,6 @@ class MixtureDatabase:
|
|
121
114
|
with open(file=json_name, mode='w') as file:
|
122
115
|
file.write(self.json)
|
123
116
|
|
124
|
-
def target_asr_data(self, t_id: int) -> str | None:
|
125
|
-
"""Get the ASR data for the given target ID
|
126
|
-
|
127
|
-
:param t_id: Target ID
|
128
|
-
:return: ASR text or None
|
129
|
-
"""
|
130
|
-
from .tokenized_shell_vars import tokenized_expand
|
131
|
-
|
132
|
-
name, _ = tokenized_expand(self.target_file(t_id).name)
|
133
|
-
return self.asr_manifest_data.get(name, None)
|
134
|
-
|
135
|
-
def mixture_asr_data(self, m_id: int) -> list[str | None]:
|
136
|
-
"""Get the ASR data for the given mixid
|
137
|
-
|
138
|
-
:param m_id: Zero-based mixture ID
|
139
|
-
:return: List of ASR text or None
|
140
|
-
"""
|
141
|
-
return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
|
142
|
-
|
143
|
-
@cached_property
|
144
|
-
def asr_manifest_data(self) -> dict[str, str]:
|
145
|
-
"""Get ASR data
|
146
|
-
|
147
|
-
Each line of a manifest file should be in the following format:
|
148
|
-
|
149
|
-
{"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
|
150
|
-
|
151
|
-
The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
|
152
|
-
text field should contain the full transcript for the utterance, and the duration field should reflect the
|
153
|
-
duration of the utterance in seconds.
|
154
|
-
|
155
|
-
Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
|
156
|
-
on one line. The fields that describe the file should be separated by commas, and have the form
|
157
|
-
"field_name": value, as shown above.
|
158
|
-
|
159
|
-
Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
|
160
|
-
directory as the manifest, or even in any specific directory structure.
|
161
|
-
|
162
|
-
The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
|
163
|
-
text.
|
164
|
-
"""
|
165
|
-
import json
|
166
|
-
|
167
|
-
from sonusai import SonusAIError
|
168
|
-
from .tokenized_shell_vars import tokenized_expand
|
169
|
-
|
170
|
-
expected_keys = ['audio_filepath', 'text', 'duration']
|
171
|
-
|
172
|
-
def _error_preamble(e_name: str, e_line_num: int) -> str:
|
173
|
-
return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
|
174
|
-
|
175
|
-
asr_manifest_data: dict[str, str] = {}
|
176
|
-
|
177
|
-
for name in self.asr_manifests:
|
178
|
-
expanded_name, _ = tokenized_expand(name)
|
179
|
-
with open(file=expanded_name, mode='r') as f:
|
180
|
-
line_num = 1
|
181
|
-
for line in f:
|
182
|
-
result = json.loads(line.strip())
|
183
|
-
|
184
|
-
for key in expected_keys:
|
185
|
-
if key not in result:
|
186
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
|
187
|
-
|
188
|
-
for key in result.keys():
|
189
|
-
if key not in expected_keys:
|
190
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
|
191
|
-
|
192
|
-
key, _ = tokenized_expand(result['audio_filepath'])
|
193
|
-
value = result['text']
|
194
|
-
|
195
|
-
if key in asr_manifest_data:
|
196
|
-
SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
|
197
|
-
|
198
|
-
asr_manifest_data[key] = value
|
199
|
-
|
200
|
-
line_num += 1
|
201
|
-
|
202
|
-
return asr_manifest_data
|
203
|
-
|
204
117
|
@cached_property
|
205
118
|
def fg_config(self) -> FeatureGeneratorConfig:
|
206
119
|
return FeatureGeneratorConfig(feature_mode=self.feature,
|
@@ -293,14 +206,14 @@ class MixtureDatabase:
|
|
293
206
|
def feature_step_samples(self) -> int:
|
294
207
|
return self.ft_config.R * self.fg_decimation * self.fg_step
|
295
208
|
|
296
|
-
def total_samples(self,
|
297
|
-
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(
|
209
|
+
def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
|
210
|
+
return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
|
298
211
|
|
299
|
-
def total_transform_frames(self,
|
300
|
-
return self.total_samples(
|
212
|
+
def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
|
213
|
+
return self.total_samples(m_ids) // self.ft_config.R
|
301
214
|
|
302
|
-
def total_feature_frames(self,
|
303
|
-
return self.total_samples(
|
215
|
+
def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
|
216
|
+
return self.total_samples(m_ids) // self.feature_step_samples
|
304
217
|
|
305
218
|
def mixture_transform_frames(self, samples: int) -> int:
|
306
219
|
return samples // self.ft_config.R
|
@@ -308,24 +221,15 @@ class MixtureDatabase:
|
|
308
221
|
def mixture_feature_frames(self, samples: int) -> int:
|
309
222
|
return samples // self.feature_step_samples
|
310
223
|
|
311
|
-
def mixids_to_list(self,
|
224
|
+
def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
|
312
225
|
"""Resolve generalized mixture IDs to a list of integers
|
313
226
|
|
314
|
-
:param
|
227
|
+
:param m_ids: Generalized mixture IDs
|
315
228
|
:return: List of mixture ID integers
|
316
229
|
"""
|
317
230
|
from .helpers import generic_ids_to_list
|
318
231
|
|
319
|
-
return generic_ids_to_list(self.num_mixtures,
|
320
|
-
|
321
|
-
@cached_property
|
322
|
-
def asr_manifests(self) -> list[str]:
|
323
|
-
"""Get ASR manifests from db
|
324
|
-
|
325
|
-
:return: ASR manifests
|
326
|
-
"""
|
327
|
-
with self.db() as c:
|
328
|
-
return [str(item[0]) for item in c.execute("SELECT asr_manifest.manifest FROM asr_manifest").fetchall()]
|
232
|
+
return generic_ids_to_list(self.num_mixtures, m_ids)
|
329
233
|
|
330
234
|
@cached_property
|
331
235
|
def class_labels(self) -> list[str]:
|
@@ -408,7 +312,8 @@ class MixtureDatabase:
|
|
408
312
|
|
409
313
|
with self.db() as c:
|
410
314
|
target_files: TargetFiles = []
|
411
|
-
for target in c.execute(
|
315
|
+
for target in c.execute(
|
316
|
+
"SELECT target_file.name, samples, level_type, id, speaker_id FROM target_file").fetchall():
|
412
317
|
truth_settings: TruthSettings = []
|
413
318
|
for ts in c.execute(
|
414
319
|
"SELECT truth_setting.setting " +
|
@@ -423,7 +328,8 @@ class MixtureDatabase:
|
|
423
328
|
target_files.append(TargetFile(name=target[0],
|
424
329
|
samples=target[1],
|
425
330
|
level_type=target[2],
|
426
|
-
truth_settings=truth_settings
|
331
|
+
truth_settings=truth_settings,
|
332
|
+
speaker_id=target[4]))
|
427
333
|
return target_files
|
428
334
|
|
429
335
|
@cached_property
|
@@ -720,7 +626,7 @@ class MixtureDatabase:
|
|
720
626
|
|
721
627
|
:param m_id: Zero-based mixture ID
|
722
628
|
:param targets: List of augmented target audio data (one per target in the mixup)
|
723
|
-
:param target: Augmented target audio for the given
|
629
|
+
:param target: Augmented target audio for the given m_id
|
724
630
|
:param force: Force computing data from original sources regardless of whether cached data exists
|
725
631
|
:return: Augmented target transform data
|
726
632
|
"""
|
@@ -1078,97 +984,312 @@ class MixtureDatabase:
|
|
1078
984
|
return class_count
|
1079
985
|
|
1080
986
|
@cached_property
|
1081
|
-
def
|
1082
|
-
|
987
|
+
def speaker_metadata_tiers(self) -> list[str]:
|
988
|
+
import json
|
1083
989
|
|
1084
|
-
|
1085
|
-
|
1086
|
-
data: dict[str, dict[str, SpeechMetadata]] = {}
|
1087
|
-
for file in self.target_files:
|
1088
|
-
data[file.name] = {}
|
1089
|
-
file_name, _ = tokenized_expand(file.name)
|
1090
|
-
tg_file = Path(file_name).with_suffix('.TextGrid')
|
1091
|
-
if tg_file.exists():
|
1092
|
-
tg = textgrid.openTextgrid(str(tg_file), includeEmptyIntervals=False)
|
1093
|
-
for tier in tg.tierNames:
|
1094
|
-
entries = tg.getTier(tier).entries
|
1095
|
-
if len(entries) > 1:
|
1096
|
-
data[file.name][tier] = entries
|
1097
|
-
else:
|
1098
|
-
data[file.name][tier] = entries[0].label
|
990
|
+
with self.db() as c:
|
991
|
+
return json.loads(c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
|
1099
992
|
|
1100
|
-
|
993
|
+
@cached_property
|
994
|
+
def textgrid_metadata_tiers(self) -> list[str]:
|
995
|
+
import json
|
996
|
+
|
997
|
+
with self.db() as c:
|
998
|
+
return json.loads(c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
|
1101
999
|
|
1102
1000
|
@cached_property
|
1103
1001
|
def speech_metadata_tiers(self) -> list[str]:
|
1104
|
-
return sorted(
|
1002
|
+
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1105
1003
|
|
1106
|
-
def
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1004
|
+
def speaker(self, speaker_id: int | None, tier: str) -> Optional[str]:
|
1005
|
+
if speaker_id is None:
|
1006
|
+
return None
|
1007
|
+
|
1008
|
+
with self.db() as c:
|
1009
|
+
data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (speaker_id,)).fetchone()
|
1010
|
+
if data is None:
|
1011
|
+
return None
|
1012
|
+
if data[0] is None:
|
1013
|
+
return None
|
1014
|
+
return data[0]
|
1015
|
+
|
1016
|
+
def speech_metadata(self, tier: str) -> list[str]:
|
1017
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1018
|
+
|
1019
|
+
results: set[str] = set()
|
1020
|
+
if tier in self.textgrid_metadata_tiers:
|
1021
|
+
for target_file in self.target_files:
|
1022
|
+
data = get_textgrid_tier_from_target_file(target_file.name, tier)
|
1023
|
+
if data is None:
|
1024
|
+
continue
|
1025
|
+
if isinstance(data, list):
|
1026
|
+
for item in data:
|
1027
|
+
results.add(item.label)
|
1028
|
+
else:
|
1029
|
+
results.add(data)
|
1030
|
+
elif tier in self.speaker_metadata_tiers:
|
1031
|
+
for target_file in self.target_files:
|
1032
|
+
data = self.speaker(target_file.speaker_id, tier)
|
1033
|
+
if data is not None:
|
1034
|
+
results.add(data)
|
1035
|
+
|
1036
|
+
return sorted(results)
|
1037
|
+
|
1038
|
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
|
1039
|
+
from praatio.utilities.constants import Interval
|
1040
|
+
|
1041
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1042
|
+
|
1043
|
+
results: list[SpeechMetadata] = []
|
1044
|
+
is_textgrid = tier in self.textgrid_metadata_tiers
|
1045
|
+
if is_textgrid:
|
1046
|
+
for target in self.mixture(mixid).targets:
|
1047
|
+
data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
|
1048
|
+
if data is not None:
|
1049
|
+
if isinstance(data, list):
|
1050
|
+
# Check for tempo augmentation and adjust Interval start and end data as needed
|
1051
|
+
entries = []
|
1052
|
+
for entry in data:
|
1053
|
+
if target.augmentation.tempo is not None:
|
1054
|
+
entries.append(Interval(entry.start / target.augmentation.tempo,
|
1055
|
+
entry.end / target.augmentation.tempo,
|
1056
|
+
entry.label))
|
1057
|
+
else:
|
1058
|
+
entries.append(entry)
|
1059
|
+
results.append(entries)
|
1060
|
+
else:
|
1061
|
+
results.append(data)
|
1062
|
+
else:
|
1063
|
+
for target in self.mixture(mixid).targets:
|
1064
|
+
data = self.speaker(self.target_file(target.file_id).speaker_id, tier)
|
1065
|
+
if data is not None:
|
1066
|
+
results.append(data)
|
1067
|
+
|
1068
|
+
return sorted(results)
|
1110
1069
|
|
1111
1070
|
def mixids_for_speech_metadata(self,
|
1112
1071
|
tier: str,
|
1113
|
-
value: str,
|
1072
|
+
value: str | None,
|
1114
1073
|
predicate: Callable[[str], bool] = None) -> list[int]:
|
1115
|
-
"""Get a list of
|
1074
|
+
"""Get a list of mixture IDs for the given speech metadata tier.
|
1116
1075
|
|
1117
|
-
If 'predicate' is None, then include
|
1118
|
-
not None, then ignore 'value' and use the given callable to determine which entries
|
1076
|
+
If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
|
1077
|
+
If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
|
1078
|
+
to include.
|
1119
1079
|
|
1120
1080
|
Examples:
|
1081
|
+
>>> mixdb = MixtureDatabase('/mixdb_location')
|
1121
1082
|
|
1122
1083
|
>>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
|
1123
|
-
Get
|
1084
|
+
Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
|
1124
1085
|
|
1125
1086
|
>>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
|
1126
|
-
Get
|
1087
|
+
Get mixture IDs for mixtures with speakers whose ages are less than 25.
|
1127
1088
|
|
1128
1089
|
>>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
|
1129
|
-
Get
|
1090
|
+
Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
|
1130
1091
|
"""
|
1092
|
+
from .helpers import get_textgrid_tier_from_target_file
|
1093
|
+
|
1131
1094
|
if predicate is None:
|
1132
|
-
def predicate(x: str) -> bool:
|
1095
|
+
def predicate(x: str | None) -> bool:
|
1133
1096
|
return x == value
|
1134
1097
|
|
1135
1098
|
# First get list of matching target files
|
1136
|
-
target_files = [
|
1137
|
-
|
1099
|
+
target_files: list[str] = []
|
1100
|
+
is_textgrid = tier in self.textgrid_metadata_tiers
|
1101
|
+
for target_file in self.target_files:
|
1102
|
+
if is_textgrid:
|
1103
|
+
metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
|
1104
|
+
else:
|
1105
|
+
metadata = self.speaker(target_file.speaker_id, tier)
|
1138
1106
|
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1107
|
+
if not isinstance(metadata, list) and predicate(metadata):
|
1108
|
+
target_files.append(target_file.name)
|
1109
|
+
|
1110
|
+
# Next get list of mixture IDs that contain those target files
|
1111
|
+
m_ids: list[int] = []
|
1112
|
+
for m_id in self.mixids_to_list():
|
1113
|
+
mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(m_id).targets]
|
1143
1114
|
for mixid_target_file in mixid_target_files:
|
1144
1115
|
if mixid_target_file in target_files:
|
1145
|
-
|
1116
|
+
m_ids.append(m_id)
|
1146
1117
|
|
1147
|
-
# Return sorted, unique list of
|
1148
|
-
return sorted(list(set(
|
1118
|
+
# Return sorted, unique list of mixture IDs
|
1119
|
+
return sorted(list(set(m_ids)))
|
1149
1120
|
|
1150
|
-
def
|
1151
|
-
|
1152
|
-
for target in self.mixture(mixid).targets:
|
1153
|
-
data = self._speech_metadata[self.target_file(target.file_id).name].get(tier)
|
1121
|
+
def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
|
1122
|
+
from .helpers import mixture_all_speech_metadata
|
1154
1123
|
|
1155
|
-
|
1156
|
-
results.append(None)
|
1157
|
-
elif isinstance(data, list):
|
1158
|
-
# Check for tempo augmentation and adjust Interval start and end data as needed
|
1159
|
-
entries = []
|
1160
|
-
for entry in data:
|
1161
|
-
if target.augmentation.tempo is not None:
|
1162
|
-
entries.append(Interval(entry.start / target.augmentation.tempo,
|
1163
|
-
entry.end / target.augmentation.tempo,
|
1164
|
-
entry.label))
|
1165
|
-
else:
|
1166
|
-
entries.append(entry)
|
1124
|
+
return mixture_all_speech_metadata(self, self.mixture(m_id))
|
1167
1125
|
|
1168
|
-
|
1169
|
-
|
1126
|
+
def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
|
1127
|
+
"""Get metric data for the given mixture ID
|
1128
|
+
|
1129
|
+
:param m_id: Zero-based mixture ID
|
1130
|
+
:param metric: Metric data to retrieve
|
1131
|
+
:param force: Force computing data from original sources regardless of whether cached data exists
|
1132
|
+
:return: Metric data
|
1133
|
+
"""
|
1134
|
+
from sonusai import SonusAIError
|
1135
|
+
|
1136
|
+
supported_metrics = (
|
1137
|
+
'MXSNR',
|
1138
|
+
'MXSSNRAVG',
|
1139
|
+
'MXSSNRSTD',
|
1140
|
+
'MXSSNRDAVG',
|
1141
|
+
'MXSSNRDSTD',
|
1142
|
+
'MXPESQ',
|
1143
|
+
'MXWSDR',
|
1144
|
+
'MXPD',
|
1145
|
+
'MXSTOI',
|
1146
|
+
'MXCSIG',
|
1147
|
+
'MXCBAK',
|
1148
|
+
'MXCOVL',
|
1149
|
+
'TDCO',
|
1150
|
+
'TMIN',
|
1151
|
+
'TMAX',
|
1152
|
+
'TPKDB',
|
1153
|
+
'TLRMS',
|
1154
|
+
'TPKR',
|
1155
|
+
'TTR',
|
1156
|
+
'TCR',
|
1157
|
+
'TFL',
|
1158
|
+
'TPKC',
|
1159
|
+
'NDCO',
|
1160
|
+
'NMIN',
|
1161
|
+
'NMAX',
|
1162
|
+
'NPKDB',
|
1163
|
+
'NLRMS',
|
1164
|
+
'NPKR',
|
1165
|
+
'NTR',
|
1166
|
+
'NCR',
|
1167
|
+
'NFL',
|
1168
|
+
'NPKC',
|
1169
|
+
'SEDAVG',
|
1170
|
+
'SEDCNT',
|
1171
|
+
'SEDTOPN',
|
1172
|
+
)
|
1173
|
+
|
1174
|
+
if not (metric in supported_metrics or metric.startswith('MXWER')):
|
1175
|
+
raise ValueError(f'Unsupported metric: {metric}')
|
1176
|
+
|
1177
|
+
if not force:
|
1178
|
+
result = self.read_mixture_data(m_id, metric)
|
1179
|
+
if result is not None:
|
1180
|
+
return result
|
1181
|
+
|
1182
|
+
mixture = self.mixture(m_id)
|
1183
|
+
if mixture is None:
|
1184
|
+
raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
|
1185
|
+
|
1186
|
+
if metric.startswith('MXWER'):
|
1187
|
+
return None
|
1188
|
+
|
1189
|
+
if metric == 'MXSNR':
|
1190
|
+
return self.snrs
|
1191
|
+
|
1192
|
+
if metric == 'MXSSNRAVG':
|
1193
|
+
return None
|
1194
|
+
|
1195
|
+
if metric == 'MXSSNRSTD':
|
1196
|
+
return None
|
1197
|
+
|
1198
|
+
if metric == 'MXSSNRDAVG':
|
1199
|
+
return None
|
1200
|
+
|
1201
|
+
if metric == 'MXSSNRDSTD':
|
1202
|
+
return None
|
1203
|
+
|
1204
|
+
if metric == 'MXPESQ':
|
1205
|
+
return None
|
1206
|
+
|
1207
|
+
if metric == 'MXWSDR':
|
1208
|
+
return None
|
1209
|
+
|
1210
|
+
if metric == 'MXPD':
|
1211
|
+
return None
|
1212
|
+
|
1213
|
+
if metric == 'MXSTOI':
|
1214
|
+
return None
|
1215
|
+
|
1216
|
+
if metric == 'MXCSIG':
|
1217
|
+
return None
|
1218
|
+
|
1219
|
+
if metric == 'MXCBAK':
|
1220
|
+
return None
|
1221
|
+
|
1222
|
+
if metric == 'MXCOVL':
|
1223
|
+
return None
|
1224
|
+
|
1225
|
+
if metric == 'TDCO':
|
1226
|
+
return None
|
1227
|
+
|
1228
|
+
if metric == 'TMIN':
|
1229
|
+
return None
|
1230
|
+
|
1231
|
+
if metric == 'TMAX':
|
1232
|
+
return None
|
1233
|
+
|
1234
|
+
if metric == 'TPKDB':
|
1235
|
+
return None
|
1236
|
+
|
1237
|
+
if metric == 'TLRMS':
|
1238
|
+
return None
|
1239
|
+
|
1240
|
+
if metric == 'TPKR':
|
1241
|
+
return None
|
1242
|
+
|
1243
|
+
if metric == 'TTR':
|
1244
|
+
return None
|
1245
|
+
|
1246
|
+
if metric == 'TCR':
|
1247
|
+
return None
|
1248
|
+
|
1249
|
+
if metric == 'TFL':
|
1250
|
+
return None
|
1251
|
+
|
1252
|
+
if metric == 'TPKC':
|
1253
|
+
return None
|
1254
|
+
|
1255
|
+
if metric == 'NDCO':
|
1256
|
+
return None
|
1257
|
+
|
1258
|
+
if metric == 'NMIN':
|
1259
|
+
return None
|
1260
|
+
|
1261
|
+
if metric == 'NMAX':
|
1262
|
+
return None
|
1263
|
+
|
1264
|
+
if metric == 'NPKDB':
|
1265
|
+
return None
|
1266
|
+
|
1267
|
+
if metric == 'NLRMS':
|
1268
|
+
return None
|
1269
|
+
|
1270
|
+
if metric == 'NPKR':
|
1271
|
+
return None
|
1272
|
+
|
1273
|
+
if metric == 'NTR':
|
1274
|
+
return None
|
1275
|
+
|
1276
|
+
if metric == 'NCR':
|
1277
|
+
return None
|
1278
|
+
|
1279
|
+
if metric == 'NFL':
|
1280
|
+
return None
|
1281
|
+
|
1282
|
+
if metric == 'NPKC':
|
1283
|
+
return None
|
1284
|
+
|
1285
|
+
if metric == 'SEDAVG':
|
1286
|
+
return None
|
1287
|
+
|
1288
|
+
if metric == 'SEDCNT':
|
1289
|
+
return None
|
1170
1290
|
|
1171
|
-
|
1291
|
+
if metric == 'SEDTOPN':
|
1292
|
+
return None
|
1172
1293
|
|
1173
1294
|
|
1174
1295
|
@lru_cache
|
@@ -1206,8 +1327,9 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1206
1327
|
from .datatypes import TruthSettings
|
1207
1328
|
|
1208
1329
|
with db() as c:
|
1209
|
-
target = c.execute(
|
1210
|
-
|
1330
|
+
target = c.execute(
|
1331
|
+
"SELECT target_file.name, samples, level_type, speaker_id FROM target_file WHERE ? = target_file.id",
|
1332
|
+
(t_id,)).fetchone()
|
1211
1333
|
|
1212
1334
|
truth_settings: TruthSettings = []
|
1213
1335
|
for ts in c.execute(
|
@@ -1223,7 +1345,8 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1223
1345
|
return TargetFile(name=target[0],
|
1224
1346
|
samples=target[1],
|
1225
1347
|
level_type=target[2],
|
1226
|
-
truth_settings=truth_settings
|
1348
|
+
truth_settings=truth_settings,
|
1349
|
+
speaker_id=target[3])
|
1227
1350
|
|
1228
1351
|
|
1229
1352
|
@lru_cache
|
@@ -84,6 +84,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
|
|
84
84
|
:return: Augmented audio
|
85
85
|
"""
|
86
86
|
import math
|
87
|
+
from pathlib import Path
|
87
88
|
import tempfile
|
88
89
|
|
89
90
|
import numpy as np
|
@@ -124,7 +125,9 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
|
|
124
125
|
except Exception as e:
|
125
126
|
raise SonusAIError(f'Error applying IR: {e}')
|
126
127
|
|
128
|
+
path = Path(temp.name)
|
127
129
|
temp.close()
|
130
|
+
path.unlink()
|
128
131
|
|
129
132
|
# Reset level to previous max value
|
130
133
|
tfm = Transformer()
|
@@ -1,4 +1,7 @@
|
|
1
|
-
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
|
4
|
+
def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
|
2
5
|
"""Expand shell variables of the forms $var, ${var} and %var%.
|
3
6
|
Unknown variables are left unchanged.
|
4
7
|
|
@@ -25,6 +28,9 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
|
|
25
28
|
if isinstance(name, bytes):
|
26
29
|
name = name.decode('utf-8')
|
27
30
|
|
31
|
+
if isinstance(name, Path):
|
32
|
+
name = name.as_posix()
|
33
|
+
|
28
34
|
name = os.fspath(name)
|
29
35
|
token_map: dict = {}
|
30
36
|
|
@@ -121,6 +127,7 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
|
|
121
127
|
else:
|
122
128
|
result += c
|
123
129
|
index += 1
|
130
|
+
|
124
131
|
return result, token_map
|
125
132
|
|
126
133
|
|
sonusai/mkwav.py
CHANGED
@@ -72,7 +72,7 @@ def _process_mixture(mixid: int) -> None:
|
|
72
72
|
|
73
73
|
from sonusai.mixture import mixture_metadata
|
74
74
|
from sonusai.utils import float_to_int16
|
75
|
-
from sonusai.utils import
|
75
|
+
from sonusai.utils import write_audio
|
76
76
|
|
77
77
|
mixture_filename = join(MP_GLOBAL.mixdb.location, MP_GLOBAL.mixdb.mixtures[mixid].name)
|
78
78
|
mixture_basename = splitext(mixture_filename)[0]
|
@@ -100,11 +100,11 @@ def _process_mixture(mixid: int) -> None:
|
|
100
100
|
if MP_GLOBAL.write_noise:
|
101
101
|
noise = np.array(f['noise'])
|
102
102
|
|
103
|
-
|
103
|
+
write_audio(name=mixture_basename + '_mixture.wav', audio=float_to_int16(mixture))
|
104
104
|
if MP_GLOBAL.write_target:
|
105
|
-
|
105
|
+
write_audio(name=mixture_basename + '_target.wav', audio=float_to_int16(target))
|
106
106
|
if MP_GLOBAL.write_noise:
|
107
|
-
|
107
|
+
write_audio(name=mixture_basename + '_noise.wav', audio=float_to_int16(noise))
|
108
108
|
|
109
109
|
with open(file=mixture_basename + '.txt', mode='w') as f:
|
110
110
|
f.write(mixture_metadata(MP_GLOBAL.mixdb, MP_GLOBAL.mixdb.mixture(mixid)))
|
sonusai/onnx_predict.py
CHANGED
@@ -100,7 +100,7 @@ def main() -> None:
|
|
100
100
|
from sonusai.utils import create_ts_name
|
101
101
|
from sonusai.utils import load_ort_session
|
102
102
|
from sonusai.utils import reshape_inputs
|
103
|
-
from sonusai.utils import
|
103
|
+
from sonusai.utils import write_audio
|
104
104
|
|
105
105
|
mixdb_path = None
|
106
106
|
mixdb = None
|
@@ -201,7 +201,7 @@ def main() -> None:
|
|
201
201
|
predict = np.transpose(predict, [1, 0, 2])
|
202
202
|
predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
|
203
203
|
owav_name = splitext(output_fname)[0] + '_predict.wav'
|
204
|
-
|
204
|
+
write_audio(owav_name, predict_audio)
|
205
205
|
|
206
206
|
|
207
207
|
if __name__ == '__main__':
|