sonusai 0.17.3__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/mixdb.py CHANGED
@@ -1,15 +1,12 @@
  from functools import cached_property
  from functools import lru_cache
  from functools import partial
- from pathlib import Path
  from sqlite3 import Connection
  from sqlite3 import Cursor
  from typing import Any
  from typing import Callable
  from typing import Optional

- from praatio import textgrid
- from praatio.utilities.constants import Interval
  from sonusai.mixture.datatypes import AudioF
  from sonusai.mixture.datatypes import AudioT
  from sonusai.mixture.datatypes import AudiosF
@@ -33,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
  from sonusai.mixture.datatypes import TransformConfig
  from sonusai.mixture.datatypes import Truth
  from sonusai.mixture.datatypes import UniversalSNR
- from sonusai.mixture.tokenized_shell_vars import tokenized_expand


  def db_file(location: str, test: bool = False) -> str:
@@ -87,14 +83,12 @@ class MixtureDatabase:
  def __init__(self, location: str, test: bool = False) -> None:
  self.location = location
  self.db = partial(SQLiteContextManager, self.location, test)
- self._speaker_metadata_tiers: list[str] = []

  @cached_property
  def json(self) -> str:
  from .datatypes import MixtureDatabaseConfig

  config = MixtureDatabaseConfig(
- asr_manifest=self.asr_manifests,
  class_balancing=self.class_balancing,
  class_labels=self.class_labels,
  class_weights_threshold=self.class_weights_thresholds,
@@ -120,86 +114,6 @@ class MixtureDatabase:
  with open(file=json_name, mode='w') as file:
  file.write(self.json)

- def target_asr_data(self, t_id: int) -> str | None:
- """Get the ASR data for the given target ID
-
- :param t_id: Target ID
- :return: ASR text or None
- """
- from .tokenized_shell_vars import tokenized_expand
-
- name, _ = tokenized_expand(self.target_file(t_id).name)
- return self.asr_manifest_data.get(name, None)
-
- def mixture_asr_data(self, m_id: int) -> list[str | None]:
- """Get the ASR data for the given mixid
-
- :param m_id: Zero-based mixture ID
- :return: List of ASR text or None
- """
- return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
-
- @cached_property
- def asr_manifest_data(self) -> dict[str, str]:
- """Get ASR data
-
- Each line of a manifest file should be in the following format:
-
- {"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
-
- The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
- text field should contain the full transcript for the utterance, and the duration field should reflect the
- duration of the utterance in seconds.
-
- Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
- on one line. The fields that describe the file should be separated by commas, and have the form
- "field_name": value, as shown above.
-
- Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
- directory as the manifest, or even in any specific directory structure.
-
- The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
- text.
- """
- import json
-
- from sonusai import SonusAIError
- from .tokenized_shell_vars import tokenized_expand
-
- expected_keys = ['audio_filepath', 'text', 'duration']
-
- def _error_preamble(e_name: str, e_line_num: int) -> str:
- return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
-
- asr_manifest_data: dict[str, str] = {}
-
- for name in self.asr_manifests:
- expanded_name, _ = tokenized_expand(name)
- with open(file=expanded_name, mode='r') as f:
- line_num = 1
- for line in f:
- result = json.loads(line.strip())
-
- for key in expected_keys:
- if key not in result:
- SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
-
- for key in result.keys():
- if key not in expected_keys:
- SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
-
- key, _ = tokenized_expand(result['audio_filepath'])
- value = result['text']
-
- if key in asr_manifest_data:
- SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
-
- asr_manifest_data[key] = value
-
- line_num += 1
-
- return asr_manifest_data
-
  @cached_property
  def fg_config(self) -> FeatureGeneratorConfig:
  return FeatureGeneratorConfig(feature_mode=self.feature,
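
The removed `asr_manifest_data` property consumed a JSON-lines manifest, one object per line with `audio_filepath`, `text`, and `duration` fields, and built a mapping from audio path to transcript. A minimal standalone sketch of that documented format, since 0.18.0 drops the feature; the `parse_asr_manifest` helper name is illustrative and is not part of sonusai:

```python
# Hypothetical sketch of the manifest format the removed asr_manifest_data
# property parsed (one JSON object per line); parse_asr_manifest is illustrative.
import json

manifest_line = '{"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}'


def parse_asr_manifest(lines: list[str]) -> dict[str, str]:
    """Map audio_filepath to text, mirroring the documented manifest layout."""
    data: dict[str, str] = {}
    for line in lines:
        entry = json.loads(line.strip())
        data[entry['audio_filepath']] = entry['text']
    return data


print(parse_asr_manifest([manifest_line]))
# {'/path/to/audio.wav': 'the transcription of the utterance'}
```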
@@ -292,14 +206,14 @@ class MixtureDatabase:
  def feature_step_samples(self) -> int:
  return self.ft_config.R * self.fg_decimation * self.fg_step

- def total_samples(self, mixids: GeneralizedIDs = '*') -> int:
- return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(mixids)])
+ def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
+ return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])

- def total_transform_frames(self, mixids: GeneralizedIDs = '*') -> int:
- return self.total_samples(mixids) // self.ft_config.R
+ def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
+ return self.total_samples(m_ids) // self.ft_config.R

- def total_feature_frames(self, mixids: GeneralizedIDs = '*') -> int:
- return self.total_samples(mixids) // self.feature_step_samples
+ def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
+ return self.total_samples(m_ids) // self.feature_step_samples

  def mixture_transform_frames(self, samples: int) -> int:
  return samples // self.ft_config.R
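
These totals only rename the parameter from `mixids` to `m_ids`; the arithmetic is unchanged (transform frames are `samples // ft_config.R`, feature frames are `samples // feature_step_samples`). A minimal usage sketch, assuming a mixture database already exists at the illustrative location below:

```python
# Minimal sketch; the database location is illustrative.
from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/path/to/mixdb')

samples = mixdb.total_samples()                    # default '*' selects every mixture
transform_frames = mixdb.total_transform_frames()  # samples // mixdb.ft_config.R
feature_frames = mixdb.total_feature_frames()      # samples // mixdb.feature_step_samples
```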
@@ -307,24 +221,15 @@ class MixtureDatabase:
  def mixture_feature_frames(self, samples: int) -> int:
  return samples // self.feature_step_samples

- def mixids_to_list(self, mixids: Optional[GeneralizedIDs] = None) -> list[int]:
+ def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
  """Resolve generalized mixture IDs to a list of integers

- :param mixids: Generalized mixture IDs
+ :param m_ids: Generalized mixture IDs
  :return: List of mixture ID integers
  """
  from .helpers import generic_ids_to_list

- return generic_ids_to_list(self.num_mixtures, mixids)
-
- @cached_property
- def asr_manifests(self) -> list[str]:
- """Get ASR manifests from db
-
- :return: ASR manifests
- """
- with self.db() as c:
- return [str(item[0]) for item in c.execute("SELECT asr_manifest.manifest FROM asr_manifest").fetchall()]
+ return generic_ids_to_list(self.num_mixtures, m_ids)

  @cached_property
  def class_labels(self) -> list[str]:
@@ -407,7 +312,8 @@ class MixtureDatabase:

  with self.db() as c:
  target_files: TargetFiles = []
- for target in c.execute("SELECT target_file.name, samples, level_type, id FROM target_file").fetchall():
+ for target in c.execute(
+ "SELECT target_file.name, samples, level_type, id, speaker_id FROM target_file").fetchall():
  truth_settings: TruthSettings = []
  for ts in c.execute(
  "SELECT truth_setting.setting " +
@@ -422,7 +328,8 @@
  target_files.append(TargetFile(name=target[0],
  samples=target[1],
  level_type=target[2],
- truth_settings=truth_settings))
+ truth_settings=truth_settings,
+ speaker_id=target[4]))
  return target_files

  @cached_property
@@ -719,7 +626,7 @@ class MixtureDatabase:

  :param m_id: Zero-based mixture ID
  :param targets: List of augmented target audio data (one per target in the mixup)
- :param target: Augmented target audio for the given mixid
+ :param target: Augmented target audio for the given m_id
  :param force: Force computing data from original sources regardless of whether cached data exists
  :return: Augmented target transform data
  """
@@ -1077,97 +984,312 @@ class MixtureDatabase:
  return class_count

  @cached_property
- def _speech_metadata(self) -> dict[str, dict[str, SpeechMetadata]]:
- """Speech metadata is a nested dictionary.
+ def speaker_metadata_tiers(self) -> list[str]:
+ import json

- data['target_file_name'] = { 'tier': SpeechMetadata, ... }
- """
- data: dict[str, dict[str, SpeechMetadata]] = {}
- for file in self.target_files:
- data[file.name] = {}
- file_name, _ = tokenized_expand(file.name)
- tg_file = Path(file_name).with_suffix('.TextGrid')
- if tg_file.exists():
- tg = textgrid.openTextgrid(str(tg_file), includeEmptyIntervals=False)
- for tier in tg.tierNames:
- entries = tg.getTier(tier).entries
- if len(entries) > 1:
- data[file.name][tier] = entries
- else:
- data[file.name][tier] = entries[0].label
+ with self.db() as c:
+ return json.loads(c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])

- return data
+ @cached_property
+ def textgrid_metadata_tiers(self) -> list[str]:
+ import json
+
+ with self.db() as c:
+ return json.loads(c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])

  @cached_property
  def speech_metadata_tiers(self) -> list[str]:
- return sorted(list(set([key for value in self._speech_metadata.values() for key in value.keys()])))
+ return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))

- def speech_metadata_all(self, tier: str) -> list[SpeechMetadata]:
- results = sorted(
- set([value.get(tier) for value in self._speech_metadata.values() if isinstance(value.get(tier), str)]))
- return results
+ def speaker(self, speaker_id: int | None, tier: str) -> Optional[str]:
+ if speaker_id is None:
+ return None
+
+ with self.db() as c:
+ data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (speaker_id,)).fetchone()
+ if data is None:
+ return None
+ if data[0] is None:
+ return None
+ return data[0]
+
+ def speech_metadata(self, tier: str) -> list[str]:
+ from .helpers import get_textgrid_tier_from_target_file
+
+ results: set[str] = set()
+ if tier in self.textgrid_metadata_tiers:
+ for target_file in self.target_files:
+ data = get_textgrid_tier_from_target_file(target_file.name, tier)
+ if data is None:
+ continue
+ if isinstance(data, list):
+ for item in data:
+ results.add(item.label)
+ else:
+ results.add(data)
+ elif tier in self.speaker_metadata_tiers:
+ for target_file in self.target_files:
+ data = self.speaker(target_file.speaker_id, tier)
+ if data is not None:
+ results.add(data)
+
+ return sorted(results)
+
+ def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
+ from praatio.utilities.constants import Interval
+
+ from .helpers import get_textgrid_tier_from_target_file
+
+ results: list[SpeechMetadata] = []
+ is_textgrid = tier in self.textgrid_metadata_tiers
+ if is_textgrid:
+ for target in self.mixture(mixid).targets:
+ data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
+ if data is not None:
+ if isinstance(data, list):
+ # Check for tempo augmentation and adjust Interval start and end data as needed
+ entries = []
+ for entry in data:
+ if target.augmentation.tempo is not None:
+ entries.append(Interval(entry.start / target.augmentation.tempo,
+ entry.end / target.augmentation.tempo,
+ entry.label))
+ else:
+ entries.append(entry)
+ results.append(entries)
+ else:
+ results.append(data)
+ else:
+ for target in self.mixture(mixid).targets:
+ data = self.speaker(self.target_file(target.file_id).speaker_id, tier)
+ if data is not None:
+ results.append(data)
+
+ return sorted(results)

  def mixids_for_speech_metadata(self,
  tier: str,
- value: str,
+ value: str | None,
  predicate: Callable[[str], bool] = None) -> list[int]:
- """Get a list of mixids for the given speech metadata tier.
+ """Get a list of mixture IDs for the given speech metadata tier.

- If 'predicate' is None, then include mixids whose tier values are equal to the given 'value'. If 'predicate' is
- not None, then ignore 'value' and use the given callable to determine which entries to include.
+ If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
+ If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
+ to include.

  Examples:
+ >>> mixdb = MixtureDatabase('/mixdb_location')

  >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
- Get mixids for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
+ Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.

  >>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
- Get mixids for mixtures with speakers whose ages are less than 25.
+ Get mixture IDs for mixtures with speakers whose ages are less than 25.

  >>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
- Get mixids for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
+ Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
  """
+ from .helpers import get_textgrid_tier_from_target_file
+
  if predicate is None:
- def predicate(x: str) -> bool:
+ def predicate(x: str | None) -> bool:
  return x == value

  # First get list of matching target files
- target_files = [k for k, v in self._speech_metadata.items() if
- isinstance(v.get(tier), str) and predicate(str(v.get(tier)))]
+ target_files: list[str] = []
+ is_textgrid = tier in self.textgrid_metadata_tiers
+ for target_file in self.target_files:
+ if is_textgrid:
+ metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
+ else:
+ metadata = self.speaker(target_file.speaker_id, tier)

- # Next get list of mixids that contain those target files
- mixids: list[int] = []
- for mixid in self.mixids_to_list():
- mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(mixid).targets]
+ if not isinstance(metadata, list) and predicate(metadata):
+ target_files.append(target_file.name)
+
+ # Next get list of mixture IDs that contain those target files
+ m_ids: list[int] = []
+ for m_id in self.mixids_to_list():
+ mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(m_id).targets]
  for mixid_target_file in mixid_target_files:
  if mixid_target_file in target_files:
- mixids.append(mixid)
+ m_ids.append(m_id)

- # Return sorted, unique list of mixids
- return sorted(list(set(mixids)))
+ # Return sorted, unique list of mixture IDs
+ return sorted(list(set(m_ids)))

- def get_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
- results: list[SpeechMetadata] = []
- for target in self.mixture(mixid).targets:
- data = self._speech_metadata[self.target_file(target.file_id).name].get(tier)
+ def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
+ from .helpers import mixture_all_speech_metadata

- if data is None:
- results.append(None)
- elif isinstance(data, list):
- # Check for tempo augmentation and adjust Interval start and end data as needed
- entries = []
- for entry in data:
- if target.augmentation.tempo is not None:
- entries.append(Interval(entry.start / target.augmentation.tempo,
- entry.end / target.augmentation.tempo,
- entry.label))
- else:
- entries.append(entry)
+ return mixture_all_speech_metadata(self, self.mixture(m_id))

- else:
- results.append(data)
+ def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
+ """Get metric data for the given mixture ID
+
+ :param m_id: Zero-based mixture ID
+ :param metric: Metric data to retrieve
+ :param force: Force computing data from original sources regardless of whether cached data exists
+ :return: Metric data
+ """
+ from sonusai import SonusAIError
+
+ supported_metrics = (
+ 'MXSNR',
+ 'MXSSNRAVG',
+ 'MXSSNRSTD',
+ 'MXSSNRDAVG',
+ 'MXSSNRDSTD',
+ 'MXPESQ',
+ 'MXWSDR',
+ 'MXPD',
+ 'MXSTOI',
+ 'MXCSIG',
+ 'MXCBAK',
+ 'MXCOVL',
+ 'TDCO',
+ 'TMIN',
+ 'TMAX',
+ 'TPKDB',
+ 'TLRMS',
+ 'TPKR',
+ 'TTR',
+ 'TCR',
+ 'TFL',
+ 'TPKC',
+ 'NDCO',
+ 'NMIN',
+ 'NMAX',
+ 'NPKDB',
+ 'NLRMS',
+ 'NPKR',
+ 'NTR',
+ 'NCR',
+ 'NFL',
+ 'NPKC',
+ 'SEDAVG',
+ 'SEDCNT',
+ 'SEDTOPN',
+ )
+
+ if not (metric in supported_metrics or metric.startswith('MXWER')):
+ raise ValueError(f'Unsupported metric: {metric}')
+
+ if not force:
+ result = self.read_mixture_data(m_id, metric)
+ if result is not None:
+ return result
+
+ mixture = self.mixture(m_id)
+ if mixture is None:
+ raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
+
+ if metric.startswith('MXWER'):
+ return None
+
+ if metric == 'MXSNR':
+ return self.snrs
+
+ if metric == 'MXSSNRAVG':
+ return None
+
+ if metric == 'MXSSNRSTD':
+ return None
+
+ if metric == 'MXSSNRDAVG':
+ return None
+
+ if metric == 'MXSSNRDSTD':
+ return None
+
+ if metric == 'MXPESQ':
+ return None
+
+ if metric == 'MXWSDR':
+ return None
+
+ if metric == 'MXPD':
+ return None
+
+ if metric == 'MXSTOI':
+ return None
+
+ if metric == 'MXCSIG':
+ return None
+
+ if metric == 'MXCBAK':
+ return None
+
+ if metric == 'MXCOVL':
+ return None
+
+ if metric == 'TDCO':
+ return None
+
+ if metric == 'TMIN':
+ return None
+
+ if metric == 'TMAX':
+ return None
+
+ if metric == 'TPKDB':
+ return None
+
+ if metric == 'TLRMS':
+ return None
+
+ if metric == 'TPKR':
+ return None
+
+ if metric == 'TTR':
+ return None
+
+ if metric == 'TCR':
+ return None
+
+ if metric == 'TFL':
+ return None
+
+ if metric == 'TPKC':
+ return None
+
+ if metric == 'NDCO':
+ return None
+
+ if metric == 'NMIN':
+ return None
+
+ if metric == 'NMAX':
+ return None
+
+ if metric == 'NPKDB':
+ return None
+
+ if metric == 'NLRMS':
+ return None
+
+ if metric == 'NPKR':
+ return None
+
+ if metric == 'NTR':
+ return None
+
+ if metric == 'NCR':
+ return None
+
+ if metric == 'NFL':
+ return None
+
+ if metric == 'NPKC':
+ return None
+
+ if metric == 'SEDAVG':
+ return None
+
+ if metric == 'SEDCNT':
+ return None

- return results
+ if metric == 'SEDTOPN':
+ return None


  @lru_cache
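
The 0.18.0 accessors above replace the TextGrid-scanning `_speech_metadata` cache with database-backed speaker and TextGrid tiers, and add a `mixture_metric` lookup whose entries are mostly stubs that return None in this release. A hedged usage sketch; the database location and the `'dialect'`/`'age'` tier names are illustrative and depend on the dataset:

```python
# Minimal sketch of the new metadata and metric accessors; location and tier
# names are illustrative.
from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/path/to/mixdb')

print(mixdb.speech_metadata_tiers)                   # union of speaker and TextGrid tiers
print(mixdb.speech_metadata('dialect'))              # distinct values across all target files
print(mixdb.mixture_speech_metadata(0, 'dialect'))   # per-target values for mixture 0

young = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
snr = mixdb.mixture_metric(0, 'MXSNR')               # most other metric names still return None
```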
@@ -1205,8 +1327,9 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
  from .datatypes import TruthSettings

  with db() as c:
- target = c.execute("SELECT target_file.name, samples, level_type FROM target_file WHERE ? = target_file.id",
- (t_id,)).fetchone()
+ target = c.execute(
+ "SELECT target_file.name, samples, level_type, speaker_id FROM target_file WHERE ? = target_file.id",
+ (t_id,)).fetchone()

  truth_settings: TruthSettings = []
  for ts in c.execute(
@@ -1222,7 +1345,8 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
  return TargetFile(name=target[0],
  samples=target[1],
  level_type=target[2],
- truth_settings=truth_settings)
+ truth_settings=truth_settings,
+ speaker_id=target[3])


  @lru_cache
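
Target file records now carry a `speaker_id` column that keys into the new speaker table queried by `MixtureDatabase.speaker`. A brief sketch; the database location and tier name are illustrative:

```python
# Minimal sketch; location and tier name are illustrative.
from sonusai.mixture.mixdb import MixtureDatabase

mixdb = MixtureDatabase('/path/to/mixdb')

tf = mixdb.target_files[0]                         # TargetFile entries now include .speaker_id
dialect = mixdb.speaker(tf.speaker_id, 'dialect')  # None when the speaker or tier is absent
```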
sonusai/mixture/tokenized_shell_vars.py CHANGED
@@ -1,4 +1,7 @@
- def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
+ from pathlib import Path
+
+
+ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
  """Expand shell variables of the forms $var, ${var} and %var%.
  Unknown variables are left unchanged.

@@ -25,6 +28,9 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
  if isinstance(name, bytes):
  name = name.decode('utf-8')

+ if isinstance(name, Path):
+ name = name.as_posix()
+
  name = os.fspath(name)
  token_map: dict = {}

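
`tokenized_expand` now accepts `pathlib.Path` in addition to `str` and `bytes`, converting it with `as_posix()` before expansion. A minimal sketch; the environment variable and paths are illustrative:

```python
# Minimal sketch of the widened tokenized_expand signature; the variable and
# paths below are illustrative.
import os
from pathlib import Path

from sonusai.mixture.tokenized_shell_vars import tokenized_expand

os.environ['MY_DATA'] = '/data/corpora'  # illustrative value
expanded, token_map = tokenized_expand(Path('$MY_DATA/speech/sample.wav'))
print(expanded)   # '/data/corpora/speech/sample.wav'
print(token_map)  # records which tokens were expanded
```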
@@ -121,6 +127,7 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
  else:
  result += c
  index += 1
+
  return result, token_map


@@ -6,37 +6,19 @@ from praatio.utilities.constants import Interval
  from .types import TimeAlignedType


- def _get_duration(name: str) -> float:
- from pydub import AudioSegment
-
- from sonusai import SonusAIError
-
- try:
- return AudioSegment.from_file(name).duration_seconds
- except Exception as e:
- raise SonusAIError(f'Error reading {name}: {e}')
-
-
  def create_textgrid(prompt: Path,
- speaker_id: str,
- speaker: dict,
  output_dir: Path,
  text: TimeAlignedType = None,
  words: list[TimeAlignedType] = None,
  phonemes: list[TimeAlignedType] = None) -> None:
- if text is not None or words is not None or phonemes is not None:
- min_t, max_t = _get_min_max({'phonemes': phonemes,
- 'text': [text],
- 'words': words})
- else:
- min_t = 0
- max_t = _get_duration(str(prompt))
+ if text is None and words is None and phonemes is None:
+ return

- tg = textgrid.Textgrid()
+ min_t, max_t = _get_min_max({'phonemes': phonemes,
+ 'text': [text],
+ 'words': words})

- tg.addTier(textgrid.IntervalTier('speaker_id', [Interval(min_t, max_t, speaker_id)], min_t, max_t))
- for tier in speaker.keys():
- tg.addTier(textgrid.IntervalTier(tier, [Interval(min_t, max_t, str(speaker[tier]))], min_t, max_t))
+ tg = textgrid.Textgrid()

  if text is not None:
  entries = [Interval(text.start, text.end, text.text)]