sonusai 0.17.2__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. sonusai/__init__.py +0 -1
  2. sonusai/audiofe.py +3 -3
  3. sonusai/calc_metric_spenh.py +81 -52
  4. sonusai/doc/doc.py +0 -24
  5. sonusai/genmetrics.py +146 -0
  6. sonusai/genmixdb.py +0 -2
  7. sonusai/mixture/__init__.py +0 -1
  8. sonusai/mixture/constants.py +0 -1
  9. sonusai/mixture/datatypes.py +2 -9
  10. sonusai/mixture/generation.py +136 -38
  11. sonusai/mixture/helpers.py +58 -1
  12. sonusai/mixture/mapped_snr_f.py +56 -9
  13. sonusai/mixture/mixdb.py +293 -170
  14. sonusai/mixture/sox_augmentation.py +3 -0
  15. sonusai/mixture/tokenized_shell_vars.py +8 -1
  16. sonusai/mkwav.py +4 -4
  17. sonusai/onnx_predict.py +2 -2
  18. sonusai/post_spenh_targetf.py +2 -2
  19. sonusai/speech/textgrid.py +6 -24
  20. sonusai/speech/{voxceleb2.py → voxceleb.py} +19 -3
  21. sonusai/utils/__init__.py +1 -1
  22. sonusai/utils/asr_functions/aaware_whisper.py +2 -2
  23. sonusai/utils/{wave.py → write_audio.py} +2 -2
  24. {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/METADATA +4 -1
  25. {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/RECORD +27 -33
  26. sonusai/mixture/speaker_metadata.py +0 -35
  27. sonusai/mkmanifest.py +0 -209
  28. sonusai/utils/asr_manifest_functions/__init__.py +0 -6
  29. sonusai/utils/asr_manifest_functions/data.py +0 -1
  30. sonusai/utils/asr_manifest_functions/librispeech.py +0 -46
  31. sonusai/utils/asr_manifest_functions/mcgill_speech.py +0 -29
  32. sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +0 -66
  33. {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/WHEEL +0 -0
  34. {sonusai-0.17.2.dist-info → sonusai-0.18.0.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -1,16 +1,12 @@
1
1
  from functools import cached_property
2
2
  from functools import lru_cache
3
3
  from functools import partial
4
- from pathlib import Path
5
4
  from sqlite3 import Connection
6
5
  from sqlite3 import Cursor
7
6
  from typing import Any
8
7
  from typing import Callable
9
8
  from typing import Optional
10
9
 
11
- from praatio import textgrid
12
- from praatio.utilities.constants import Interval
13
-
14
10
  from sonusai.mixture.datatypes import AudioF
15
11
  from sonusai.mixture.datatypes import AudioT
16
12
  from sonusai.mixture.datatypes import AudiosF
@@ -34,7 +30,6 @@ from sonusai.mixture.datatypes import TargetFiles
34
30
  from sonusai.mixture.datatypes import TransformConfig
35
31
  from sonusai.mixture.datatypes import Truth
36
32
  from sonusai.mixture.datatypes import UniversalSNR
37
- from sonusai.mixture.tokenized_shell_vars import tokenized_expand
38
33
 
39
34
 
40
35
  def db_file(location: str, test: bool = False) -> str:
@@ -88,14 +83,12 @@ class MixtureDatabase:
88
83
  def __init__(self, location: str, test: bool = False) -> None:
89
84
  self.location = location
90
85
  self.db = partial(SQLiteContextManager, self.location, test)
91
- self._speaker_metadata_tiers: list[str] = []
92
86
 
93
87
  @cached_property
94
88
  def json(self) -> str:
95
89
  from .datatypes import MixtureDatabaseConfig
96
90
 
97
91
  config = MixtureDatabaseConfig(
98
- asr_manifest=self.asr_manifests,
99
92
  class_balancing=self.class_balancing,
100
93
  class_labels=self.class_labels,
101
94
  class_weights_threshold=self.class_weights_thresholds,
@@ -121,86 +114,6 @@ class MixtureDatabase:
121
114
  with open(file=json_name, mode='w') as file:
122
115
  file.write(self.json)
123
116
 
124
- def target_asr_data(self, t_id: int) -> str | None:
125
- """Get the ASR data for the given target ID
126
-
127
- :param t_id: Target ID
128
- :return: ASR text or None
129
- """
130
- from .tokenized_shell_vars import tokenized_expand
131
-
132
- name, _ = tokenized_expand(self.target_file(t_id).name)
133
- return self.asr_manifest_data.get(name, None)
134
-
135
- def mixture_asr_data(self, m_id: int) -> list[str | None]:
136
- """Get the ASR data for the given mixid
137
-
138
- :param m_id: Zero-based mixture ID
139
- :return: List of ASR text or None
140
- """
141
- return [self.target_asr_data(target.file_id) for target in self.mixture(m_id).targets]
142
-
143
- @cached_property
144
- def asr_manifest_data(self) -> dict[str, str]:
145
- """Get ASR data
146
-
147
- Each line of a manifest file should be in the following format:
148
-
149
- {"audio_filepath": "/path/to/audio.wav", "text": "the transcription of the utterance", "duration": 23.147}
150
-
151
- The audio_filepath field should provide an absolute path to the audio file corresponding to the utterance. The
152
- text field should contain the full transcript for the utterance, and the duration field should reflect the
153
- duration of the utterance in seconds.
154
-
155
- Each entry in the manifest (describing one audio file) should be bordered by '{' and '}' and must be contained
156
- on one line. The fields that describe the file should be separated by commas, and have the form
157
- "field_name": value, as shown above.
158
-
159
- Since the manifest specifies the path for each utterance, the audio files do not have to be located in the same
160
- directory as the manifest, or even in any specific directory structure.
161
-
162
- The manifest dictionary consists of key/value pairs where the keys are target file names and the values are ASR
163
- text.
164
- """
165
- import json
166
-
167
- from sonusai import SonusAIError
168
- from .tokenized_shell_vars import tokenized_expand
169
-
170
- expected_keys = ['audio_filepath', 'text', 'duration']
171
-
172
- def _error_preamble(e_name: str, e_line_num: int) -> str:
173
- return f'Invalid entry in ASR manifest {e_name} line {e_line_num}'
174
-
175
- asr_manifest_data: dict[str, str] = {}
176
-
177
- for name in self.asr_manifests:
178
- expanded_name, _ = tokenized_expand(name)
179
- with open(file=expanded_name, mode='r') as f:
180
- line_num = 1
181
- for line in f:
182
- result = json.loads(line.strip())
183
-
184
- for key in expected_keys:
185
- if key not in result:
186
- SonusAIError(f'{_error_preamble(name, line_num)}: missing field "{key}"')
187
-
188
- for key in result.keys():
189
- if key not in expected_keys:
190
- SonusAIError(f'{_error_preamble(name, line_num)}: unknown field "{key}"')
191
-
192
- key, _ = tokenized_expand(result['audio_filepath'])
193
- value = result['text']
194
-
195
- if key in asr_manifest_data:
196
- SonusAIError(f'{_error_preamble(name, line_num)}: entry already exists')
197
-
198
- asr_manifest_data[key] = value
199
-
200
- line_num += 1
201
-
202
- return asr_manifest_data
203
-
204
117
  @cached_property
205
118
  def fg_config(self) -> FeatureGeneratorConfig:
206
119
  return FeatureGeneratorConfig(feature_mode=self.feature,
@@ -293,14 +206,14 @@ class MixtureDatabase:
293
206
  def feature_step_samples(self) -> int:
294
207
  return self.ft_config.R * self.fg_decimation * self.fg_step
295
208
 
296
- def total_samples(self, mixids: GeneralizedIDs = '*') -> int:
297
- return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(mixids)])
209
+ def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
210
+ return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
298
211
 
299
- def total_transform_frames(self, mixids: GeneralizedIDs = '*') -> int:
300
- return self.total_samples(mixids) // self.ft_config.R
212
+ def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
213
+ return self.total_samples(m_ids) // self.ft_config.R
301
214
 
302
- def total_feature_frames(self, mixids: GeneralizedIDs = '*') -> int:
303
- return self.total_samples(mixids) // self.feature_step_samples
215
+ def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
216
+ return self.total_samples(m_ids) // self.feature_step_samples
304
217
 
305
218
  def mixture_transform_frames(self, samples: int) -> int:
306
219
  return samples // self.ft_config.R
@@ -308,24 +221,15 @@ class MixtureDatabase:
308
221
  def mixture_feature_frames(self, samples: int) -> int:
309
222
  return samples // self.feature_step_samples
310
223
 
311
- def mixids_to_list(self, mixids: Optional[GeneralizedIDs] = None) -> list[int]:
224
+ def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
312
225
  """Resolve generalized mixture IDs to a list of integers
313
226
 
314
- :param mixids: Generalized mixture IDs
227
+ :param m_ids: Generalized mixture IDs
315
228
  :return: List of mixture ID integers
316
229
  """
317
230
  from .helpers import generic_ids_to_list
318
231
 
319
- return generic_ids_to_list(self.num_mixtures, mixids)
320
-
321
- @cached_property
322
- def asr_manifests(self) -> list[str]:
323
- """Get ASR manifests from db
324
-
325
- :return: ASR manifests
326
- """
327
- with self.db() as c:
328
- return [str(item[0]) for item in c.execute("SELECT asr_manifest.manifest FROM asr_manifest").fetchall()]
232
+ return generic_ids_to_list(self.num_mixtures, m_ids)
329
233
 
330
234
  @cached_property
331
235
  def class_labels(self) -> list[str]:
@@ -408,7 +312,8 @@ class MixtureDatabase:
408
312
 
409
313
  with self.db() as c:
410
314
  target_files: TargetFiles = []
411
- for target in c.execute("SELECT target_file.name, samples, level_type, id FROM target_file").fetchall():
315
+ for target in c.execute(
316
+ "SELECT target_file.name, samples, level_type, id, speaker_id FROM target_file").fetchall():
412
317
  truth_settings: TruthSettings = []
413
318
  for ts in c.execute(
414
319
  "SELECT truth_setting.setting " +
@@ -423,7 +328,8 @@ class MixtureDatabase:
423
328
  target_files.append(TargetFile(name=target[0],
424
329
  samples=target[1],
425
330
  level_type=target[2],
426
- truth_settings=truth_settings))
331
+ truth_settings=truth_settings,
332
+ speaker_id=target[4]))
427
333
  return target_files
428
334
 
429
335
  @cached_property
@@ -720,7 +626,7 @@ class MixtureDatabase:
720
626
 
721
627
  :param m_id: Zero-based mixture ID
722
628
  :param targets: List of augmented target audio data (one per target in the mixup)
723
- :param target: Augmented target audio for the given mixid
629
+ :param target: Augmented target audio for the given m_id
724
630
  :param force: Force computing data from original sources regardless of whether cached data exists
725
631
  :return: Augmented target transform data
726
632
  """
@@ -1078,97 +984,312 @@ class MixtureDatabase:
1078
984
  return class_count
1079
985
 
1080
986
  @cached_property
1081
- def _speech_metadata(self) -> dict[str, dict[str, SpeechMetadata]]:
1082
- """Speech metadata is a nested dictionary.
987
+ def speaker_metadata_tiers(self) -> list[str]:
988
+ import json
1083
989
 
1084
- data['target_file_name'] = { 'tier': SpeechMetadata, ... }
1085
- """
1086
- data: dict[str, dict[str, SpeechMetadata]] = {}
1087
- for file in self.target_files:
1088
- data[file.name] = {}
1089
- file_name, _ = tokenized_expand(file.name)
1090
- tg_file = Path(file_name).with_suffix('.TextGrid')
1091
- if tg_file.exists():
1092
- tg = textgrid.openTextgrid(str(tg_file), includeEmptyIntervals=False)
1093
- for tier in tg.tierNames:
1094
- entries = tg.getTier(tier).entries
1095
- if len(entries) > 1:
1096
- data[file.name][tier] = entries
1097
- else:
1098
- data[file.name][tier] = entries[0].label
990
+ with self.db() as c:
991
+ return json.loads(c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
1099
992
 
1100
- return data
993
+ @cached_property
994
+ def textgrid_metadata_tiers(self) -> list[str]:
995
+ import json
996
+
997
+ with self.db() as c:
998
+ return json.loads(c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()[0])
1101
999
 
1102
1000
  @cached_property
1103
1001
  def speech_metadata_tiers(self) -> list[str]:
1104
- return sorted(list(set([key for value in self._speech_metadata.values() for key in value.keys()])))
1002
+ return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
1105
1003
 
1106
- def speech_metadata_all(self, tier: str) -> list[SpeechMetadata]:
1107
- results = sorted(
1108
- set([value.get(tier) for value in self._speech_metadata.values() if isinstance(value.get(tier), str)]))
1109
- return results
1004
+ def speaker(self, speaker_id: int | None, tier: str) -> Optional[str]:
1005
+ if speaker_id is None:
1006
+ return None
1007
+
1008
+ with self.db() as c:
1009
+ data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (speaker_id,)).fetchone()
1010
+ if data is None:
1011
+ return None
1012
+ if data[0] is None:
1013
+ return None
1014
+ return data[0]
1015
+
1016
+ def speech_metadata(self, tier: str) -> list[str]:
1017
+ from .helpers import get_textgrid_tier_from_target_file
1018
+
1019
+ results: set[str] = set()
1020
+ if tier in self.textgrid_metadata_tiers:
1021
+ for target_file in self.target_files:
1022
+ data = get_textgrid_tier_from_target_file(target_file.name, tier)
1023
+ if data is None:
1024
+ continue
1025
+ if isinstance(data, list):
1026
+ for item in data:
1027
+ results.add(item.label)
1028
+ else:
1029
+ results.add(data)
1030
+ elif tier in self.speaker_metadata_tiers:
1031
+ for target_file in self.target_files:
1032
+ data = self.speaker(target_file.speaker_id, tier)
1033
+ if data is not None:
1034
+ results.add(data)
1035
+
1036
+ return sorted(results)
1037
+
1038
+ def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
1039
+ from praatio.utilities.constants import Interval
1040
+
1041
+ from .helpers import get_textgrid_tier_from_target_file
1042
+
1043
+ results: list[SpeechMetadata] = []
1044
+ is_textgrid = tier in self.textgrid_metadata_tiers
1045
+ if is_textgrid:
1046
+ for target in self.mixture(mixid).targets:
1047
+ data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
1048
+ if data is not None:
1049
+ if isinstance(data, list):
1050
+ # Check for tempo augmentation and adjust Interval start and end data as needed
1051
+ entries = []
1052
+ for entry in data:
1053
+ if target.augmentation.tempo is not None:
1054
+ entries.append(Interval(entry.start / target.augmentation.tempo,
1055
+ entry.end / target.augmentation.tempo,
1056
+ entry.label))
1057
+ else:
1058
+ entries.append(entry)
1059
+ results.append(entries)
1060
+ else:
1061
+ results.append(data)
1062
+ else:
1063
+ for target in self.mixture(mixid).targets:
1064
+ data = self.speaker(self.target_file(target.file_id).speaker_id, tier)
1065
+ if data is not None:
1066
+ results.append(data)
1067
+
1068
+ return sorted(results)
1110
1069
 
1111
1070
  def mixids_for_speech_metadata(self,
1112
1071
  tier: str,
1113
- value: str,
1072
+ value: str | None,
1114
1073
  predicate: Callable[[str], bool] = None) -> list[int]:
1115
- """Get a list of mixids for the given speech metadata tier.
1074
+ """Get a list of mixture IDs for the given speech metadata tier.
1116
1075
 
1117
- If 'predicate' is None, then include mixids whose tier values are equal to the given 'value'. If 'predicate' is
1118
- not None, then ignore 'value' and use the given callable to determine which entries to include.
1076
+ If 'predicate' is None, then include mixture IDs whose tier values are equal to the given 'value'.
1077
+ If 'predicate' is not None, then ignore 'value' and use the given callable to determine which entries
1078
+ to include.
1119
1079
 
1120
1080
  Examples:
1081
+ >>> mixdb = MixtureDatabase('/mixdb_location')
1121
1082
 
1122
1083
  >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
1123
- Get mixids for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1084
+ Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1124
1085
 
1125
1086
  >>> mixids = mixdb.mixids_for_speech_metadata('age', '', lambda x: int(x) < 25)
1126
- Get mixids for mixtures with speakers whose ages are less than 25.
1087
+ Get mixture IDs for mixtures with speakers whose ages are less than 25.
1127
1088
 
1128
1089
  >>> mixids = mixdb.mixids_for_speech_metadata('dialect', '', lambda x: x in ['New York City', 'Northern'])
1129
- Get mixids for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1090
+ Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1130
1091
  """
1092
+ from .helpers import get_textgrid_tier_from_target_file
1093
+
1131
1094
  if predicate is None:
1132
- def predicate(x: str) -> bool:
1095
+ def predicate(x: str | None) -> bool:
1133
1096
  return x == value
1134
1097
 
1135
1098
  # First get list of matching target files
1136
- target_files = [k for k, v in self._speech_metadata.items() if
1137
- isinstance(v.get(tier), str) and predicate(str(v.get(tier)))]
1099
+ target_files: list[str] = []
1100
+ is_textgrid = tier in self.textgrid_metadata_tiers
1101
+ for target_file in self.target_files:
1102
+ if is_textgrid:
1103
+ metadata = get_textgrid_tier_from_target_file(target_file.name, tier)
1104
+ else:
1105
+ metadata = self.speaker(target_file.speaker_id, tier)
1138
1106
 
1139
- # Next get list of mixids that contain those target files
1140
- mixids: list[int] = []
1141
- for mixid in self.mixids_to_list():
1142
- mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(mixid).targets]
1107
+ if not isinstance(metadata, list) and predicate(metadata):
1108
+ target_files.append(target_file.name)
1109
+
1110
+ # Next get list of mixture IDs that contain those target files
1111
+ m_ids: list[int] = []
1112
+ for m_id in self.mixids_to_list():
1113
+ mixid_target_files = [self.target_file(target.file_id).name for target in self.mixture(m_id).targets]
1143
1114
  for mixid_target_file in mixid_target_files:
1144
1115
  if mixid_target_file in target_files:
1145
- mixids.append(mixid)
1116
+ m_ids.append(m_id)
1146
1117
 
1147
- # Return sorted, unique list of mixids
1148
- return sorted(list(set(mixids)))
1118
+ # Return sorted, unique list of mixture IDs
1119
+ return sorted(list(set(m_ids)))
1149
1120
 
1150
- def get_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
1151
- results: list[SpeechMetadata] = []
1152
- for target in self.mixture(mixid).targets:
1153
- data = self._speech_metadata[self.target_file(target.file_id).name].get(tier)
1121
+ def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
1122
+ from .helpers import mixture_all_speech_metadata
1154
1123
 
1155
- if data is None:
1156
- results.append(None)
1157
- elif isinstance(data, list):
1158
- # Check for tempo augmentation and adjust Interval start and end data as needed
1159
- entries = []
1160
- for entry in data:
1161
- if target.augmentation.tempo is not None:
1162
- entries.append(Interval(entry.start / target.augmentation.tempo,
1163
- entry.end / target.augmentation.tempo,
1164
- entry.label))
1165
- else:
1166
- entries.append(entry)
1124
+ return mixture_all_speech_metadata(self, self.mixture(m_id))
1167
1125
 
1168
- else:
1169
- results.append(data)
1126
+ def mixture_metric(self, m_id: int, metric: str, force: bool = False) -> Any:
1127
+ """Get metric data for the given mixture ID
1128
+
1129
+ :param m_id: Zero-based mixture ID
1130
+ :param metric: Metric data to retrieve
1131
+ :param force: Force computing data from original sources regardless of whether cached data exists
1132
+ :return: Metric data
1133
+ """
1134
+ from sonusai import SonusAIError
1135
+
1136
+ supported_metrics = (
1137
+ 'MXSNR',
1138
+ 'MXSSNRAVG',
1139
+ 'MXSSNRSTD',
1140
+ 'MXSSNRDAVG',
1141
+ 'MXSSNRDSTD',
1142
+ 'MXPESQ',
1143
+ 'MXWSDR',
1144
+ 'MXPD',
1145
+ 'MXSTOI',
1146
+ 'MXCSIG',
1147
+ 'MXCBAK',
1148
+ 'MXCOVL',
1149
+ 'TDCO',
1150
+ 'TMIN',
1151
+ 'TMAX',
1152
+ 'TPKDB',
1153
+ 'TLRMS',
1154
+ 'TPKR',
1155
+ 'TTR',
1156
+ 'TCR',
1157
+ 'TFL',
1158
+ 'TPKC',
1159
+ 'NDCO',
1160
+ 'NMIN',
1161
+ 'NMAX',
1162
+ 'NPKDB',
1163
+ 'NLRMS',
1164
+ 'NPKR',
1165
+ 'NTR',
1166
+ 'NCR',
1167
+ 'NFL',
1168
+ 'NPKC',
1169
+ 'SEDAVG',
1170
+ 'SEDCNT',
1171
+ 'SEDTOPN',
1172
+ )
1173
+
1174
+ if not (metric in supported_metrics or metric.startswith('MXWER')):
1175
+ raise ValueError(f'Unsupported metric: {metric}')
1176
+
1177
+ if not force:
1178
+ result = self.read_mixture_data(m_id, metric)
1179
+ if result is not None:
1180
+ return result
1181
+
1182
+ mixture = self.mixture(m_id)
1183
+ if mixture is None:
1184
+ raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
1185
+
1186
+ if metric.startswith('MXWER'):
1187
+ return None
1188
+
1189
+ if metric == 'MXSNR':
1190
+ return self.snrs
1191
+
1192
+ if metric == 'MXSSNRAVG':
1193
+ return None
1194
+
1195
+ if metric == 'MXSSNRSTD':
1196
+ return None
1197
+
1198
+ if metric == 'MXSSNRDAVG':
1199
+ return None
1200
+
1201
+ if metric == 'MXSSNRDSTD':
1202
+ return None
1203
+
1204
+ if metric == 'MXPESQ':
1205
+ return None
1206
+
1207
+ if metric == 'MXWSDR':
1208
+ return None
1209
+
1210
+ if metric == 'MXPD':
1211
+ return None
1212
+
1213
+ if metric == 'MXSTOI':
1214
+ return None
1215
+
1216
+ if metric == 'MXCSIG':
1217
+ return None
1218
+
1219
+ if metric == 'MXCBAK':
1220
+ return None
1221
+
1222
+ if metric == 'MXCOVL':
1223
+ return None
1224
+
1225
+ if metric == 'TDCO':
1226
+ return None
1227
+
1228
+ if metric == 'TMIN':
1229
+ return None
1230
+
1231
+ if metric == 'TMAX':
1232
+ return None
1233
+
1234
+ if metric == 'TPKDB':
1235
+ return None
1236
+
1237
+ if metric == 'TLRMS':
1238
+ return None
1239
+
1240
+ if metric == 'TPKR':
1241
+ return None
1242
+
1243
+ if metric == 'TTR':
1244
+ return None
1245
+
1246
+ if metric == 'TCR':
1247
+ return None
1248
+
1249
+ if metric == 'TFL':
1250
+ return None
1251
+
1252
+ if metric == 'TPKC':
1253
+ return None
1254
+
1255
+ if metric == 'NDCO':
1256
+ return None
1257
+
1258
+ if metric == 'NMIN':
1259
+ return None
1260
+
1261
+ if metric == 'NMAX':
1262
+ return None
1263
+
1264
+ if metric == 'NPKDB':
1265
+ return None
1266
+
1267
+ if metric == 'NLRMS':
1268
+ return None
1269
+
1270
+ if metric == 'NPKR':
1271
+ return None
1272
+
1273
+ if metric == 'NTR':
1274
+ return None
1275
+
1276
+ if metric == 'NCR':
1277
+ return None
1278
+
1279
+ if metric == 'NFL':
1280
+ return None
1281
+
1282
+ if metric == 'NPKC':
1283
+ return None
1284
+
1285
+ if metric == 'SEDAVG':
1286
+ return None
1287
+
1288
+ if metric == 'SEDCNT':
1289
+ return None
1170
1290
 
1171
- return results
1291
+ if metric == 'SEDTOPN':
1292
+ return None
1172
1293
 
1173
1294
 
1174
1295
  @lru_cache
@@ -1206,8 +1327,9 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1206
1327
  from .datatypes import TruthSettings
1207
1328
 
1208
1329
  with db() as c:
1209
- target = c.execute("SELECT target_file.name, samples, level_type FROM target_file WHERE ? = target_file.id",
1210
- (t_id,)).fetchone()
1330
+ target = c.execute(
1331
+ "SELECT target_file.name, samples, level_type, speaker_id FROM target_file WHERE ? = target_file.id",
1332
+ (t_id,)).fetchone()
1211
1333
 
1212
1334
  truth_settings: TruthSettings = []
1213
1335
  for ts in c.execute(
@@ -1223,7 +1345,8 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1223
1345
  return TargetFile(name=target[0],
1224
1346
  samples=target[1],
1225
1347
  level_type=target[2],
1226
- truth_settings=truth_settings)
1348
+ truth_settings=truth_settings,
1349
+ speaker_id=target[3])
1227
1350
 
1228
1351
 
1229
1352
  @lru_cache
@@ -84,6 +84,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
84
84
  :return: Augmented audio
85
85
  """
86
86
  import math
87
+ from pathlib import Path
87
88
  import tempfile
88
89
 
89
90
  import numpy as np
@@ -124,7 +125,9 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
124
125
  except Exception as e:
125
126
  raise SonusAIError(f'Error applying IR: {e}')
126
127
 
128
+ path = Path(temp.name)
127
129
  temp.close()
130
+ path.unlink()
128
131
 
129
132
  # Reset level to previous max value
130
133
  tfm = Transformer()
@@ -1,4 +1,7 @@
1
- def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
1
+ from pathlib import Path
2
+
3
+
4
+ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
2
5
  """Expand shell variables of the forms $var, ${var} and %var%.
3
6
  Unknown variables are left unchanged.
4
7
 
@@ -25,6 +28,9 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
25
28
  if isinstance(name, bytes):
26
29
  name = name.decode('utf-8')
27
30
 
31
+ if isinstance(name, Path):
32
+ name = name.as_posix()
33
+
28
34
  name = os.fspath(name)
29
35
  token_map: dict = {}
30
36
 
@@ -121,6 +127,7 @@ def tokenized_expand(name: str | bytes) -> tuple[str, dict[str, str]]:
121
127
  else:
122
128
  result += c
123
129
  index += 1
130
+
124
131
  return result, token_map
125
132
 
126
133
 
sonusai/mkwav.py CHANGED
@@ -72,7 +72,7 @@ def _process_mixture(mixid: int) -> None:
72
72
 
73
73
  from sonusai.mixture import mixture_metadata
74
74
  from sonusai.utils import float_to_int16
75
- from sonusai.utils import write_wav
75
+ from sonusai.utils import write_audio
76
76
 
77
77
  mixture_filename = join(MP_GLOBAL.mixdb.location, MP_GLOBAL.mixdb.mixtures[mixid].name)
78
78
  mixture_basename = splitext(mixture_filename)[0]
@@ -100,11 +100,11 @@ def _process_mixture(mixid: int) -> None:
100
100
  if MP_GLOBAL.write_noise:
101
101
  noise = np.array(f['noise'])
102
102
 
103
- write_wav(name=mixture_basename + '_mixture.wav', audio=float_to_int16(mixture))
103
+ write_audio(name=mixture_basename + '_mixture.wav', audio=float_to_int16(mixture))
104
104
  if MP_GLOBAL.write_target:
105
- write_wav(name=mixture_basename + '_target.wav', audio=float_to_int16(target))
105
+ write_audio(name=mixture_basename + '_target.wav', audio=float_to_int16(target))
106
106
  if MP_GLOBAL.write_noise:
107
- write_wav(name=mixture_basename + '_noise.wav', audio=float_to_int16(noise))
107
+ write_audio(name=mixture_basename + '_noise.wav', audio=float_to_int16(noise))
108
108
 
109
109
  with open(file=mixture_basename + '.txt', mode='w') as f:
110
110
  f.write(mixture_metadata(MP_GLOBAL.mixdb, MP_GLOBAL.mixdb.mixture(mixid)))
sonusai/onnx_predict.py CHANGED
@@ -100,7 +100,7 @@ def main() -> None:
100
100
  from sonusai.utils import create_ts_name
101
101
  from sonusai.utils import load_ort_session
102
102
  from sonusai.utils import reshape_inputs
103
- from sonusai.utils import write_wav
103
+ from sonusai.utils import write_audio
104
104
 
105
105
  mixdb_path = None
106
106
  mixdb = None
@@ -201,7 +201,7 @@ def main() -> None:
201
201
  predict = np.transpose(predict, [1, 0, 2])
202
202
  predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
203
203
  owav_name = splitext(output_fname)[0] + '_predict.wav'
204
- write_wav(owav_name, predict_audio)
204
+ write_audio(owav_name, predict_audio)
205
205
 
206
206
 
207
207
  if __name__ == '__main__':