sonusai 0.17.3__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/mixture/mixdb.py CHANGED
@@ -1,15 +1,12 @@
1
1
  from functools import cached_property
2
2
  from functools import lru_cache
3
3
  from functools import partial
4
- from pathlib import Path
5
4
  from sqlite3 import Connection
6
5
  from sqlite3 import Cursor
7
6
  from typing import Any
8
7
  from typing import Callable
9
8
  from typing import Optional
10
9
 
11
- from praatio import textgrid
12
- from praatio.utilities.constants import Interval
13
10
  from sonusai.mixture.datatypes import AudioF
14
11
  from sonusai.mixture.datatypes import AudioT
15
12
  from sonusai.mixture.datatypes import AudiosF
@@ -33,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
33
30
  from sonusai.mixture.datatypes import TransformConfig
34
31
  from sonusai.mixture.datatypes import Truth
35
32
  from sonusai.mixture.datatypes import UniversalSNR
36
- from sonusai.mixture.tokenized_shell_vars import tokenized_expand
37
33
 
38
34
 
39
35
  def db_file(location: str, test: bool = False) -> str:
@@ -87,14 +83,12 @@ class MixtureDatabase:
def __init__(self, location: str, test: bool = False) -> None:
    """Bind this database object to its storage location.

    :param location: Mixture database location
    :param test: Flag forwarded unchanged to SQLiteContextManager
    """
    self.location = location
    # Calling self.db() builds a new SQLiteContextManager for this location each time
    self.db = partial(SQLiteContextManager, self.location, test)
91
86
 
92
87
  @cached_property
93
88
  def json(self) -> str:
94
89
  from .datatypes import MixtureDatabaseConfig
95
90
 
96
91
  config = MixtureDatabaseConfig(
97
- asr_manifest=self.asr_manifests,
98
92
  class_balancing=self.class_balancing,
99
93
  class_labels=self.class_labels,
100
94
  class_weights_threshold=self.class_weights_thresholds,
@@ -120,86 +114,6 @@ class MixtureDatabase:
120
114
  with open(file=json_name, mode='w') as file:
121
115
  file.write(self.json)
122
116
 
123
- def target_asr_data(self, t_id: int) -> str | None:
124
- """Get the ASR data for the given target ID
125
-
126
- :param t_id: Target ID
127
- :return: ASR text or None
128
- """
129
- from .tokenized_shell_vars import tokenized_expand
130
-
131
- name, _ = tokenized_expand(self.target_file(t_id).name)
132
- return self.asr_manifest_data.get(name, None)
133
-
134
- def mixture_asr_data(self, m_id: int) -> list[str | None]:
135
- """Get the ASR data for the given mixid
136
-
137
- :param m_id: Zero-based mixture ID
138
- :return: List of ASR text or None
139
- """
140
- return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
141
-
142
- @cached_property
143
- def asr_manifest_data(self) -> dict[str, str]:
144
- """Get ASR data
145
-
146
- Each line of a manifest file should be in the following format:
147
-
148
- {"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
149
-
150
- The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
151
- text field should contain the full transcript for the utterance, and the duration field should reflect the
152
- duration of the utterance in seconds.
153
-
154
- Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
155
- on one line. The fields that describe the file should be separated by commas, and have the form
156
- "field_name": value, as shown above.
157
-
158
- Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
159
- directory as the manifest, or even in any specific directory structure.
160
-
161
- The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
162
- text.
163
- """
164
- import json
165
-
166
- from sonusai import SonusAIError
167
- from .tokenized_shell_vars import tokenized_expand
168
-
169
- expected_keys = ['audio_filepath', 'text', 'duration']
170
-
171
- def _error_preamble(e_name: str, e_line_num: int) -> str:
172
- return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
173
-
174
- asr_manifest_data: dict[str, str] = {}
175
-
176
- for name in self.asr_manifests:
177
- expanded_name, _ = tokenized_expand(name)
178
- with open(file=expanded_name, mode='r') as f:
179
- line_num = 1
180
- for line in f:
181
- result = json.loads(line.strip())
182
-
183
- for key in expected_keys:
184
- if key not in result:
185
- SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
186
-
187
- for key in result.keys():
188
- if key not in expected_keys:
189
- SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
190
-
191
- key, _ = tokenized_expand(result['audio_filepath'])
192
- value = result['text']
193
-
194
- if key in asr_manifest_data:
195
- SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
196
-
197
- asr_manifest_data[key] = value
198
-
199
- line_num += 1
200
-
201
- return asr_manifest_data
202
-
203
117
  @cached_property
204
118
  def fg_config(self) -> FeatureGeneratorConfig:
205
119
  return FeatureGeneratorConfig(feature_mode=self.feature,
@@ -215,32 +129,32 @@ class MixtureDatabase:
@cached_property
def num_classes(self) -> int:
    """Number of classes, read from the ``top`` table (first row returned).

    :return: Class count as an int
    """
    with self.db() as c:
        row = c.execute("SELECT top.num_classes FROM top").fetchone()
        return int(row[0])
219
133
 
@cached_property
def truth_mutex(self) -> bool:
    """Truth-mutex flag, read from the ``top`` table (first row returned).

    :return: The stored value converted to bool
    """
    with self.db() as c:
        row = c.execute("SELECT top.truth_mutex FROM top").fetchone()
        return bool(row[0])
224
138
 
@cached_property
def truth_reduction_function(self) -> str:
    """Name of the truth reduction function, read from the ``top`` table (first row returned).

    :return: The stored value converted to str
    """
    with self.db() as c:
        row = c.execute("SELECT top.truth_reduction_function FROM top").fetchone()
        return str(row[0])
229
143
 
@cached_property
def noise_mix_mode(self) -> str:
    """Noise mix mode, read from the ``top`` table (first row returned).

    :return: The stored value converted to str
    """
    with self.db() as c:
        row = c.execute("SELECT top.noise_mix_mode FROM top").fetchone()
        return str(row[0])
234
148
 
@cached_property
def class_balancing(self) -> bool:
    """Class-balancing flag, read from the ``top`` table (first row returned).

    :return: The stored value converted to bool
    """
    with self.db() as c:
        row = c.execute("SELECT top.class_balancing FROM top").fetchone()
        return bool(row[0])
239
153
 
@cached_property
def feature(self) -> str:
    """Feature mode name, read from the ``top`` table (first row returned).

    :return: The stored value converted to str
    """
    with self.db() as c:
        row = c.execute("SELECT top.feature FROM top").fetchone()
        return str(row[0])
244
158
 
245
159
  @cached_property
246
160
  def fg_decimation(self) -> int:
@@ -292,14 +206,14 @@ class MixtureDatabase:
292
206
  def feature_step_samples(self) -> int:
293
207
  return self.ft_config.R * self.fg_decimation * self.fg_step
294
208
 
def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
    """Total number of audio samples in the selected mixtures.

    :param m_ids: Generalized mixture IDs (default '*', resolved by mixids_to_list)
    :return: Sum of the selected mixtures' sample counts
    """
    total = 0
    for m_id in self.mixids_to_list(m_ids):
        total += self.mixture(m_id).samples
    return total
297
211
 
def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
    """Total number of whole transform frames in the selected mixtures.

    :param m_ids: Generalized mixture IDs (default '*', resolved by mixids_to_list)
    :return: Total samples divided by the samples per transform frame (ft_config.R), rounded down
    """
    samples = self.total_samples(m_ids)
    return samples // self.ft_config.R
300
214
 
def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
    """Total number of whole feature frames in the selected mixtures.

    :param m_ids: Generalized mixture IDs (default '*', resolved by mixids_to_list)
    :return: Total samples divided by feature_step_samples, rounded down
    """
    samples = self.total_samples(m_ids)
    return samples // self.feature_step_samples
303
217
 
def mixture_transform_frames(self, samples: int) -> int:
    """Number of whole transform frames spanned by a sample count.

    :param samples: Number of audio samples
    :return: samples divided by the samples per transform frame (ft_config.R), rounded down
    """
    frame_step = self.ft_config.R
    return samples // frame_step
@@ -307,24 +221,15 @@ class MixtureDatabase:
def mixture_feature_frames(self, samples: int) -> int:
    """Number of whole feature frames spanned by a sample count.

    :param samples: Number of audio samples
    :return: samples divided by feature_step_samples, rounded down
    """
    step = self.feature_step_samples
    return samples // step
309
223
 
def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
    """Expand generalized mixture IDs into a concrete list of integers.

    :param m_ids: Generalized mixture IDs (None is passed through to generic_ids_to_list)
    :return: List of mixture ID integers
    """
    from .helpers import generic_ids_to_list

    mixture_count = self.num_mixtures
    return generic_ids_to_list(mixture_count, m_ids)
328
233
 
329
234
  @cached_property
330
235
  def class_labels(self) -> list[str]:
@@ -377,14 +282,16 @@ class MixtureDatabase:
377
282
 
378
283
  :return: Spectral masks
379
284
  """
285
+ from .db_datatypes import SpectralMaskRecord
286
+
380
287
  with self.db() as c:
381
- results = c.execute(
382
- "SELECT spectral_mask.f_max_width, f_num, t_max_width, t_num, t_max_percent FROM spectral_mask")
383
- return [SpectralMask(f_max_width=spectral_mask[0],
384
- f_num=spectral_mask[1],
385
- t_max_width=spectral_mask[2],
386
- t_num=spectral_mask[3],
387
- t_max_percent=spectral_mask[4]) for spectral_mask in results.fetchall()]
288
+ spectral_masks = [SpectralMaskRecord(*result) for result in
289
+ c.execute("SELECT * FROM spectral_mask").fetchall()]
290
+ return [SpectralMask(f_max_width=spectral_mask.f_max_width,
291
+ f_num=spectral_mask.f_num,
292
+ t_max_width=spectral_mask.t_max_width,
293
+ t_num=spectral_mask.t_num,
294
+ t_max_percent=spectral_mask.t_max_percent) for spectral_mask in spectral_masks]
388
295
 
389
296
  def spectral_mask(self, sm_id: int) -> SpectralMask:
390
297
  """Get spectral mask with ID from db
@@ -404,25 +311,29 @@ class MixtureDatabase:
404
311
 
405
312
  from .datatypes import TruthSetting
406
313
  from .datatypes import TruthSettings
314
+ from .db_datatypes import TargetFileRecord
407
315
 
408
316
  with self.db() as c:
409
317
  target_files: TargetFiles = []
410
- for target in c.execute("SELECT target_file.name, samples, level_type, id FROM target_file").fetchall():
318
+ target_file_records = [TargetFileRecord(*result) for result in
319
+ c.execute("SELECT * FROM target_file").fetchall()]
320
+ for target_file_record in target_file_records:
411
321
  truth_settings: TruthSettings = []
412
- for ts in c.execute(
322
+ for truth_setting_records in c.execute(
413
323
  "SELECT truth_setting.setting " +
414
324
  "FROM truth_setting, target_file_truth_setting " +
415
325
  "WHERE ? = target_file_truth_setting.target_file_id " +
416
326
  "AND truth_setting.id = target_file_truth_setting.truth_setting_id",
417
- (target[3],)).fetchall():
418
- entry = json.loads(ts[0])
419
- truth_settings.append(TruthSetting(config=entry.get('config', None),
420
- function=entry.get('function', None),
421
- index=entry.get('index', None)))
422
- target_files.append(TargetFile(name=target[0],
423
- samples=target[1],
424
- level_type=target[2],
425
- truth_settings=truth_settings))
327
+ (target_file_record.id,)).fetchall():
328
+ truth_setting = json.loads(truth_setting_records[0])
329
+ truth_settings.append(TruthSetting(config=truth_setting.get('config', None),
330
+ function=truth_setting.get('function', None),
331
+ index=truth_setting.get('index', None)))
332
+ target_files.append(TargetFile(name=target_file_record.name,
333
+ samples=target_file_record.samples,
334
+ level_type=target_file_record.level_type,
335
+ truth_settings=truth_settings,
336
+ speaker_id=target_file_record.speaker_id))
426
337
  return target_files
427
338
 
428
339
  @cached_property
@@ -532,18 +443,16 @@ class MixtureDatabase:
532
443
  """
533
444
  from .helpers import to_mixture
534
445
  from .helpers import to_target
446
+ from .db_datatypes import MixtureRecord
447
+ from .db_datatypes import TargetRecord
535
448
 
536
449
  with self.db() as c:
537
450
  mixtures: Mixtures = []
538
- for mixture in c.execute(
539
- "SELECT mixture.name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, " +
540
- "random_snr, snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain, id " +
541
- "FROM mixture").fetchall():
542
- targets = [to_target(target) for target in c.execute(
543
- "SELECT target.file_id, augmentation, gain " +
544
- "FROM target, mixture_target " +
451
+ for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
452
+ targets = [to_target(TargetRecord(*target)) for target in c.execute(
453
+ "SELECT target.* FROM target, mixture_target " +
545
454
  "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
546
- (mixture[11],)).fetchall()]
455
+ (mixture.id,)).fetchall()]
547
456
  mixtures.append(to_mixture(mixture, targets))
548
457
  return mixtures
549
458
 
@@ -567,7 +476,7 @@ class MixtureDatabase:
@cached_property
def mixid_width(self) -> int:
    """Mixture ID width, read from the ``top`` table (first row returned).

    :return: The stored value converted to int
    """
    with self.db() as c:
        row = c.execute("SELECT top.mixid_width FROM top").fetchone()
        return int(row[0])
571
480
 
572
481
  def location_filename(self, name: str) -> str:
573
482
  """Add the location to the given file name
@@ -719,7 +628,7 @@ class MixtureDatabase:
719
628
 
720
629
  :param m_id: Zero-based mixture ID
721
630
  :param targets: List of augmented target audio data (one per target in the mixup)
722
- :param target: Augmented target audio for the given mixid
631
+ :param target: Augmented target audio for the given m_id
723
632
  :param force: Force computing data from original sources regardless of whether cached data exists
724
633
  :return: Augmented target transform data
725
634
  """
@@ -1077,97 +986,298 @@ class MixtureDatabase:
1077
986
  return class_count
1078
987
 
@cached_property
def speaker_metadata_tiers(self) -> list[str]:
    """Names of the speaker metadata tiers.

    They are stored JSON-encoded in the ``speaker_metadata_tiers`` column of the ``top`` row with id 1.

    :return: Decoded JSON value (expected to be a list of tier names)
    """
    import json

    query = "SELECT speaker_metadata_tiers FROM top WHERE 1 = id"
    with self.db() as c:
        encoded = c.execute(query).fetchone()[0]
        return json.loads(encoded)
@cached_property
def textgrid_metadata_tiers(self) -> list[str]:
    """Names of the TextGrid metadata tiers.

    They are stored JSON-encoded in the ``textgrid_metadata_tiers`` column of the ``top`` row with id 1.

    :return: Decoded JSON value (expected to be a list of tier names)
    """
    import json

    query = "SELECT textgrid_metadata_tiers FROM top WHERE 1 = id"
    with self.db() as c:
        encoded = c.execute(query).fetchone()[0]
        return json.loads(encoded)
1100
1001
 
@cached_property
def speech_metadata_tiers(self) -> list[str]:
    """All speech metadata tier names: the speaker and TextGrid tiers, de-duplicated and sorted.

    :return: Sorted list of unique tier names
    """
    speaker_tiers = self.speaker_metadata_tiers
    textgrid_tiers = self.textgrid_metadata_tiers
    combined = speaker_tiers + textgrid_tiers
    return sorted(set(combined))
1005
+
def speaker(self, s_id: int | None, tier: str) -> Optional[str]:
    """Look up one speaker-table value for a speaker ID.

    :param s_id: Speaker ID, or None
    :param tier: Speaker metadata tier (speaker table column)
    :return: Whatever the module-level _speaker lookup returns for this database
    """
    db = self.db
    return _speaker(db, s_id, tier)
1008
+
def speech_metadata(self, tier: str) -> list[str]:
    """Collect every distinct value of a speech metadata tier across all target files.

    TextGrid tiers take precedence: list-valued (interval) tiers contribute each interval's label and
    other non-None values are used as-is. Otherwise, if the tier is a speaker tier, the non-None
    speaker-table values are used. A tier in neither list yields an empty result.

    :param tier: Speech metadata tier name
    :return: Sorted list of unique values
    """
    from .helpers import get_textgrid_tier_from_target_file

    values: set[str] = set()
    if tier in self.textgrid_metadata_tiers:
        for target_file in self.target_files:
            tier_data = get_textgrid_tier_from_target_file(target_file.name, tier)
            if isinstance(tier_data, list):
                values.update(interval.label for interval in tier_data)
            elif tier_data is not None:
                values.add(tier_data)
    elif tier in self.speaker_metadata_tiers:
        for target_file in self.target_files:
            speaker_value = self.speaker(target_file.speaker_id, tier)
            if speaker_value is not None:
                values.add(speaker_value)

    return sorted(values)
1030
+
def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
    """Get one speech metadata tier for each target in a mixture.

    TextGrid tiers are read from each target file's TextGrid data. For list-valued (interval) tiers, the
    interval start/end times are divided by the target's tempo augmentation when one is set. Any other
    tier is looked up in the speaker table. Targets with no data for the tier are skipped.

    :param mixid: Zero-based mixture ID
    :param tier: Speech metadata tier name
    :return: Metadata entries in the same order as the mixture's targets
    """
    from praatio.utilities.constants import Interval

    from .helpers import get_textgrid_tier_from_target_file

    is_textgrid = tier in self.textgrid_metadata_tiers

    results: list[SpeechMetadata] = []
    for target in self.mixture(mixid).targets:
        target_file = self.target_file(target.file_id)
        if is_textgrid:
            data = get_textgrid_tier_from_target_file(target_file.name, tier)
            if isinstance(data, list):
                tempo = target.augmentation.tempo
                if tempo is None:
                    # Return a fresh list (as before) rather than the helper's own object
                    data = list(data)
                else:
                    # Adjust interval start/end times for tempo augmentation
                    data = [Interval(entry.start / tempo, entry.end / tempo, entry.label) for entry in data]
        else:
            data = self.speaker(target_file.speaker_id, tier)

        if data is not None:
            results.append(data)

    # Deliberately not sorted: entries can mix str and list values, which are not mutually orderable
    # (sorted() raised TypeError), and target order keeps each entry aligned with its target.
    return results
1109
1062
 
1110
1063
  def mixids_for_speech_metadata(self,
1111
1064
  tier: str,
1112
- value: str,
1065
+ value: str | None,
1113
1066
  predicate: Callable[[str], bool] = None) -> list[int]:
1114
- """Get a list of mixids for the given speech metadata tier.
1067
+ """Get a list of mixture IDs for the given speech metadata tier.
1115
1068
 
1116
- If 'predicate' is None, then include mixids whose tier values are equal to the given 'value'. If 'predicate' is
1117
- not None, then ignore 'value' and use the given callable to determine which entries to include.
1069
+ If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
1070
+ If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
1071
+ to include.
1118
1072
 
1119
1073
  Examples:
1074
+ >>> mixdb = MixtureDatabase('/mixdb_location')
1120
1075
 
1121
1076
  >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
1122
- Get mixids for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1077
+ Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1123
1078
 
1124
1079
  >>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
1125
- Get mixids for mixtures with speakers whose ages are less than 25.
1080
+ Get mixture IDs for mixtures with speakers whose ages are less than 25.
1126
1081
 
1127
1082
  >>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
1128
- Get mixids for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1083
+ Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1129
1084
  """
1085
+ from .helpers import get_textgrid_tier_from_target_file
1086
+
1130
1087
  if predicate is None:
1131
- def predicate(x: str) -> bool:
1088
+ def predicate(x: str | None) -> bool:
1132
1089
  return x == value
1133
1090
 
1134
1091
  # First get list of matching target files
1135
- target_files = [k for k, v in self._speech_metadata.items() if
1136
- isinstance(v.get(tier), str) and predicate(str(v.get(tier)))]
1092
+ target_file_ids: list[int] = []
1093
+ is_textgrid = tier in self.textgrid_metadata_tiers
1094
+ for target_file_id, target_file in enumerate(self.target_files):
1095
+ if is_textgrid:
1096
+ metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
1097
+ else:
1098
+ metadata = self.speaker(target_file.speaker_id, tier)
1137
1099
 
1138
- # Next get list of mixids that contain those target files
1139
- mixids: list[int] = []
1140
- for mixid in self.mixids_to_list():
1141
- mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(mixid).targets]
1142
- for mixid_target_file in mixid_target_files:
1143
- if mixid_target_file in target_files:
1144
- mixids.append(mixid)
1100
+ if not isinstance(metadata, list) and predicate(metadata):
1101
+ target_file_ids.append(target_file_id + 1)
1145
1102
 
1146
- # Return sorted, unique list of mixids
1147
- return sorted(list(set(mixids)))
1103
+ # Next get list of mixture IDs that contain those target files
1104
+ with self.db() as c:
1105
+ m_ids = c.execute("SELECT mixture_id FROM mixture_target " +
1106
+ f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
1107
+ return [x[0] - 1 for x in m_ids]
1148
1108
 
def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
    """Get all speech metadata tiers for one mixture.

    This is a thin wrapper around ``helpers.mixture_all_speech_metadata``.

    :param m_id: Zero-based mixture ID
    :return: The helper's result for this mixture. It is annotated as a list of tier -> metadata dicts,
             presumably one per target; confirm in helpers.
    """
    from .helpers import mixture_all_speech_metadata as all_speech_metadata

    mixture = self.mixture(m_id)
    return all_speech_metadata(self, mixture)
def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
    """Get metric data for the given mixture ID

    Previously cached data is returned when available, unless 'force' is set. Apart from 'MXSNR', the
    supported metrics (and any 'MXWER*' name) are not computed here yet; they yield None when nothing
    is cached.

    :param m_id: Zero-based mixture ID
    :param metric: Metric data to retrieve
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Metric data, or None for metrics that are not computed yet
    :raises ValueError: If 'metric' is not a supported metric name
    :raises SonusAIError: If no mixture exists for 'm_id'
    """
    supported_metrics = (
        'MXSNR',
        'MXSSNRAVG',
        'MXSSNRSTD',
        'MXSSNRDAVG',
        'MXSSNRDSTD',
        'MXPESQ',
        'MXWSDR',
        'MXPD',
        'MXSTOI',
        'MXCSIG',
        'MXCBAK',
        'MXCOVL',
        'TDCO',
        'TMIN',
        'TMAX',
        'TPKDB',
        'TLRMS',
        'TPKR',
        'TTR',
        'TCR',
        'TFL',
        'TPKC',
        'NDCO',
        'NMIN',
        'NMAX',
        'NPKDB',
        'NLRMS',
        'NPKR',
        'NTR',
        'NCR',
        'NFL',
        'NPKC',
        'SEDAVG',
        'SEDCNT',
        'SEDTOPN',
    )

    if metric not in supported_metrics and not metric.startswith('MXWER'):
        raise ValueError(f'Unsupported metric: {metric}')

    if not force:
        result = self.read_mixture_data(m_id, metric)
        if result is not None:
            return result

    mixture = self.mixture(m_id)
    if mixture is None:
        from sonusai import SonusAIError

        raise SonusAIError(f'Could not find mixture for m_id: {m_id}')

    if metric == 'MXSNR':
        # This mixture's SNR; self.snrs is the database-wide list, not a per-mixture value
        return mixture.snr

    # TODO: compute the remaining supported metrics (and MXWER*); until then they yield None
    return None
1171
1281
 
1172
1282
 
1173
1283
  @lru_cache
@@ -1178,17 +1288,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1178
1288
  :param sm_id: Spectral mask ID
1179
1289
  :return: Spectral mask
1180
1290
  """
1291
+ from .db_datatypes import SpectralMaskRecord
1292
+
1181
1293
  with db() as c:
1182
- spectral_mask = c.execute(
1183
- "SELECT spectral_mask.f_max_width, f_num, t_max_width, t_num, t_max_percent " +
1184
- "FROM spectral_mask " +
1185
- "WHERE ? = spectral_mask.id",
1186
- (sm_id,)).fetchone()
1187
- return SpectralMask(f_max_width=spectral_mask[0],
1188
- f_num=spectral_mask[1],
1189
- t_max_width=spectral_mask[2],
1190
- t_num=spectral_mask[3],
1191
- t_max_percent=spectral_mask[4])
1294
+ spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id",
1295
+ (sm_id,)).fetchone())
1296
+ return SpectralMask(f_max_width=spectral_mask.f_max_width,
1297
+ f_num=spectral_mask.f_num,
1298
+ t_max_width=spectral_mask.t_max_width,
1299
+ t_num=spectral_mask.t_num,
1300
+ t_max_percent=spectral_mask.t_max_percent)
1192
1301
 
1193
1302
 
1194
1303
  @lru_cache
@@ -1203,10 +1312,11 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1203
1312
 
1204
1313
  from .datatypes import TruthSetting
1205
1314
  from .datatypes import TruthSettings
1315
+ from .db_datatypes import TargetFileRecord
1206
1316
 
1207
1317
  with db() as c:
1208
- target = c.execute("SELECT target_file.name, samples, level_type FROM target_file WHERE ? = target_file.id",
1209
- (t_id,)).fetchone()
1318
+ target_file = TargetFileRecord(
1319
+ *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone())
1210
1320
 
1211
1321
  truth_settings: TruthSettings = []
1212
1322
  for ts in c.execute(
@@ -1219,10 +1329,11 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1219
1329
  truth_settings.append(TruthSetting(config=entry.get('config', None),
1220
1330
  function=entry.get('function', None),
1221
1331
  index=entry.get('index', None)))
1222
- return TargetFile(name=target[0],
1223
- samples=target[1],
1224
- level_type=target[2],
1225
- truth_settings=truth_settings)
1332
+ return TargetFile(name=target_file.name,
1333
+ samples=target_file.samples,
1334
+ level_type=target_file.level_type,
1335
+ truth_settings=truth_settings,
1336
+ speaker_id=target_file.speaker_id)
1226
1337
 
1227
1338
 
1228
1339
  @lru_cache
@@ -1263,19 +1374,29 @@ def _mixture(db: partial, m_id: int) -> Mixture:
1263
1374
  """
1264
1375
  from .helpers import to_mixture
1265
1376
  from .helpers import to_target
1377
+ from .db_datatypes import MixtureRecord
1378
+ from .db_datatypes import TargetRecord
1266
1379
 
1267
1380
  with db() as c:
1268
- mixture = c.execute(
1269
- "SELECT mixture.name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, " +
1270
- "random_snr, snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain, id " +
1271
- "FROM mixture " +
1272
- "WHERE ? = mixture.id",
1273
- (m_id + 1,)).fetchone()
1274
-
1275
- targets = [to_target(target) for target in c.execute(
1276
- "SELECT target.file_id, augmentation, gain " +
1381
+ mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
1382
+ targets = [to_target(TargetRecord(*target)) for target in c.execute(
1383
+ "SELECT target.* " +
1277
1384
  "FROM target, mixture_target " +
1278
1385
  "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
1279
- (mixture[11],)).fetchall()]
1386
+ (mixture.id,)).fetchall()]
1280
1387
 
1281
1388
  return to_mixture(mixture, targets)
1389
+
1390
+
1391
+ @lru_cache
1392
+ def _speaker(db: partial, s_id: int | None, tier: str) -> Optional[str]:
1393
+ if s_id is None:
1394
+ return None
1395
+
1396
+ with db() as c:
1397
+ data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (s_id,)).fetchone()
1398
+ if data is None:
1399
+ return None
1400
+ if data[0] is None:
1401
+ return None
1402
+ return data[0]