sonusai 0.16.1__py3-none-any.whl → 0.17.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/onnx_predict.py CHANGED
@@ -1,19 +1,21 @@
1
- """sonusai predict
1
+ """sonusai onnx_predict
2
2
 
3
- usage: predict [-hvr] [-i MIXID] (-m MODEL) INPUT
3
+ usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
4
4
 
5
5
  options:
6
6
  -h, --help
7
7
  -v, --verbose Be verbose.
8
8
  -i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
9
- -m MODEL, --model MODEL Trained ONNX model file.
10
- -r, --reset Reset model between each file.
9
+ --include GLOB Search only files whose base name matches GLOB. [default: *.{wav,flac}].
10
+ -w, --write-wav Calculate the inverse transform of the prediction and write .wav files.
11
11
 
12
- Run prediction on a trained ONNX model using SonusAI genft or WAV data.
12
+ Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or audio files from a glob path.
13
+ The ONNX Runtime (ort) engine is used to execute the inference.
13
14
 
14
15
  Inputs:
15
- MODEL A SonusAI trained ONNX model file.
16
- INPUT The input data must be one of the following:
16
+ MODEL Trained ONNX model (.onnx) file; weights are expected to be embedded in the file.
17
+
18
+ DATA The input data must be one of the following:
17
19
  * WAV
18
20
  Using the given model, generate feature data and run prediction. A model file must be
19
21
  provided. The MIXID is ignored.
@@ -22,19 +24,30 @@ Inputs:
22
24
  Using the given SonusAI mixture database directory, generate feature and truth data if not found.
23
25
  Run prediction. The MIXID is required.
24
26
 
27
+
28
+ Note there are multiple ways to process model prediction over multiple audio data files (an illustrative reshape sketch follows this note):
29
+ 1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model is run in
30
+ a single inference call. If batch_size is > 1, multiple mixtures are run in one call, with shorter mixtures
31
+ zero-padded to the size of the largest mixture.
32
+ 2. TME (timestep multi-extension): the mixture is split into multiple sets of timesteps, i.e. batch[0] holds the starting timesteps, ...
33
+ Note that batches are run independently, so sequential state from one set of timesteps to the next is not
34
+ maintained; results for models with state in the timestep dimension (e.g. convolutions, LSTMs) will not match those from
35
+ TSE mode.
36
+
37
+ TBD: the modes below are not yet finalized.
38
+ 3. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This makes sense only if
39
+ independent predictions are made on each frame without considering previous frames (timesteps=1) or there is no
40
+ timestep dimension in the model (timesteps=0).
41
+ 4. Classification
42
+
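A minimal numpy sketch of the TSE packing described in the note above (illustrative only; this is not the package's reshape_inputs implementation, and the shapes and variable names are hypothetical):

    import numpy as np

    frames, stride, feature_params = 300, 1, 64     # hypothetical mixture feature shape
    feature = np.zeros((frames, stride, feature_params), dtype=np.float32)

    # TSE: every transform frame becomes a timestep of a single batch item, so the whole
    # mixture is processed in one inference call (assuming a flattened stride*feature_params input)
    tse_input = feature.reshape(1, frames, stride * feature_params)

    # With batch_size > 1, shorter mixtures would be zero-padded to the length of the longest
    # mixture and stacked along the batch dimension before the single call.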
25
43
  Outputs the following to opredict-<TIMESTAMP> directory:
26
44
  <id>.h5
27
45
  dataset: predict
28
46
  onnx_predict.log
29
47
 
30
48
  """
31
-
32
49
  import signal
33
50
 
34
- from sonusai.mixture import Feature
35
- from sonusai.mixture import Predict
36
- from sonusai.utils import SonusAIMetaData
37
-
38
51
 
39
52
  def signal_handler(_sig, _frame):
40
53
  import sys
@@ -57,193 +70,138 @@ def main() -> None:
57
70
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
58
71
 
59
72
  verbose = args['--verbose']
73
+ wav = args['--write-wav']
60
74
  mixids = args['--mixid']
61
- model_name = args['--model']
62
- reset = args['--reset']
63
- input_name = args['INPUT']
75
+ include = args['--include']
76
+ model_path = args['MODEL']
77
+ data_paths = args['DATA']
64
78
 
65
79
  from os import makedirs
80
+ from os.path import abspath
81
+ from os.path import basename
66
82
  from os.path import isdir
67
83
  from os.path import join
84
+ from os.path import normpath
85
+ from os.path import realpath
68
86
  from os.path import splitext
69
87
 
70
88
  import h5py
71
- import onnxruntime as rt
72
89
  import numpy as np
90
+ import onnxruntime as ort
73
91
 
74
92
  from sonusai import create_file_handler
75
93
  from sonusai import initial_log_messages
76
94
  from sonusai import logger
77
95
  from sonusai import update_console_handler
78
96
  from sonusai.mixture import MixtureDatabase
79
- from sonusai.mixture import get_feature_from_audio
80
- from sonusai.mixture import read_audio
97
+ from sonusai.mixture import get_audio_from_feature
98
+ from sonusai.utils import PathInfo
99
+ from sonusai.utils import braced_iglob
81
100
  from sonusai.utils import create_ts_name
82
- from sonusai.utils import get_frames_per_batch
83
- from sonusai.utils import get_sonusai_metadata
101
+ from sonusai.utils import load_ort_session
102
+ from sonusai.utils import reshape_inputs
103
+ from sonusai.utils import write_wav
84
104
 
85
- output_dir = create_ts_name('opredict')
105
+ mixdb_path = None
106
+ mixdb = None
107
+ p_mixids = None
108
+ entries: list[PathInfo] = []
109
+
110
+ if len(data_paths) == 1 and isdir(data_paths[0]):
111
+ # Assume it's a single path to SonusAI mixdb subdir
112
+ in_basename = basename(normpath(data_paths[0]))
113
+ mixdb_path = data_paths[0]
114
+ else:
115
+ # search all data paths for .wav, .flac (or whatever is specified in include)
116
+ in_basename = ''
117
+
118
+ output_dir = create_ts_name('opredict-' + in_basename)
86
119
  makedirs(output_dir, exist_ok=True)
87
120
 
88
121
  # Setup logging file
89
- create_file_handler(join(output_dir, 'onnx_predict.log'))
122
+ create_file_handler(join(output_dir, 'onnx_predict.log'))
90
123
  update_console_handler(verbose)
91
124
  initial_log_messages('onnx_predict')
92
125
 
93
- model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
94
- model_metadata = get_sonusai_metadata(model)
95
-
96
- batch_size = model_metadata.input_shape[0]
97
- if model_metadata.timestep:
98
- timesteps = model_metadata.input_shape[1]
99
- else:
100
- timesteps = 0
101
- num_classes = model_metadata.output_shape[-1]
102
-
103
- frames_per_batch = get_frames_per_batch(batch_size, timesteps)
104
-
105
- logger.info('')
106
- logger.info(f'feature {model_metadata.feature}')
107
- logger.info(f'num_classes {num_classes}')
108
- logger.info(f'batch_size {batch_size}')
109
- logger.info(f'timesteps {timesteps}')
110
- logger.info(f'flatten {model_metadata.flattened}')
111
- logger.info(f'add1ch {model_metadata.channel}')
112
- logger.info(f'truth_mutex {model_metadata.mutex}')
113
- logger.info(f'input_shape {model_metadata.input_shape}')
114
- logger.info(f'output_shape {model_metadata.output_shape}')
115
- logger.info('')
116
-
117
- if splitext(input_name)[1] == '.wav':
118
- # Convert WAV to feature data
119
- logger.info('')
120
- logger.info(f'Run prediction on {input_name}')
121
- audio = read_audio(input_name)
122
- feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
123
-
124
- predict = pad_and_predict(feature=feature,
125
- model_name=model_name,
126
- model_metadata=model_metadata,
127
- frames_per_batch=frames_per_batch,
128
- batch_size=batch_size,
129
- timesteps=timesteps,
130
- reset=reset)
131
-
132
- output_name = splitext(input_name)[0] + '.h5'
133
- with h5py.File(output_name, 'a') as f:
134
- if 'feature' in f:
135
- del f['feature']
136
- f.create_dataset(name='feature', data=feature)
137
-
138
- if 'predict' in f:
139
- del f['predict']
140
- f.create_dataset(name='predict', data=predict)
141
-
142
- logger.info(f'Saved results to {output_name}')
143
- return
144
-
145
- if not isdir(input_name):
146
- logger.exception(f'Do not know how to process input from {input_name}')
147
- raise SystemExit(1)
148
-
149
- mixdb = MixtureDatabase(input_name)
126
+ providers = ort.get_available_providers()
127
+ logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
150
128
 
151
- if mixdb.feature != model_metadata.feature:
152
- logger.exception(f'Feature in mixture database does not match feature in model')
129
+ session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
130
+ if hparams is None:
131
+ logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
153
132
  raise SystemExit(1)
154
-
155
- mixids = mixdb.mixids_to_list(mixids)
156
- if reset:
157
- # reset mode cycles through each file one at a time
158
- for mixid in mixids:
159
- feature, _ = mixdb.mixture_ft(mixid)
160
-
161
- predict = pad_and_predict(feature=feature,
162
- model_name=model_name,
163
- model_metadata=model_metadata,
164
- frames_per_batch=frames_per_batch,
165
- batch_size=batch_size,
166
- timesteps=timesteps,
167
- reset=reset)
168
-
169
- output_name = join(output_dir, mixdb.mixtures[mixid].name)
170
- with h5py.File(output_name, 'a') as f:
171
- if 'predict' in f:
172
- del f['predict']
173
- f.create_dataset(name='predict', data=predict)
133
+ if len(sess_inputs) != 1:
134
+ logger.error(f'Error: ONNX model has {len(sess_inputs)} inputs; exactly 1 is expected. Exiting due to unknown input.')
+ raise SystemExit(1)
135
+
136
+ in0name = sess_inputs[0].name
137
+ in0type = sess_inputs[0].type
138
+ out_names = [n.name for n in session.get_outputs()]
139
+
140
+ logger.info(f'Read and compiled ONNX model from {model_path}.')
141
+
142
+ if mixdb_path is not None:
143
+ # Assume it's a single path to SonusAI mixdb subdir
144
+ logger.debug(f'Attempting to load mixture database from {mixdb_path}')
145
+ mixdb = MixtureDatabase(mixdb_path)
146
+ logger.info(f'SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
147
+ p_mixids = mixdb.mixids_to_list(mixids)
148
+ if len(p_mixids) != mixdb.num_mixtures:
149
+ logger.info(f'Processing a subset of {len(p_mixids)} mixtures out of the {mixdb.num_mixtures} available.')
150
+ else:
151
+ for p in data_paths:
152
+ location = join(realpath(abspath(p)), '**', include)
153
+ logger.debug(f'Processing {location}')
154
+ for file in braced_iglob(pathname=location, recursive=True):
155
+ name = file
156
+ entries.append(PathInfo(abs_path=file, audio_filepath=name))
157
+ logger.info(f'{len(data_paths)} data paths specified, found {len(entries)} audio files.')
158
+
159
+ if 'float16' in in0type:
160
+ model_is_fp16 = True
161
+ logger.info(f'Detected input of float16, converting all feature inputs to that type.')
174
162
  else:
175
- features: list[Feature] = []
176
- file_indices: list[slice] = []
177
- total_frames = 0
178
- for mixid in mixids:
179
- current_feature, _ = mixdb.mixture_ft(mixid)
180
- current_frames = current_feature.shape[0]
181
- features.append(current_feature)
182
- file_indices.append(slice(total_frames, total_frames + current_frames))
183
- total_frames += current_frames
184
- feature = np.vstack([features[i] for i in range(len(features))])
185
-
186
- predict = pad_and_predict(feature=feature,
187
- model_name=model_name,
188
- model_metadata=model_metadata,
189
- frames_per_batch=frames_per_batch,
190
- batch_size=batch_size,
191
- timesteps=timesteps,
192
- reset=reset)
193
-
194
- # Write data to separate files
195
- for idx, mixid in enumerate(mixids):
196
- output_name = join(output_dir, mixdb.mixtures[mixid].name)
197
- with h5py.File(output_name, 'a') as f:
163
+ model_is_fp16 = False
164
+
165
+ if mixdb_path is not None and hparams['batch_size'] == 1:
166
+ # mixdb input
167
+ # Assume the mixdb feature, etc. is what the model expects
168
+ if hparams['feature'] != mixdb.feature:
169
+ logger.warning(f'Mixture feature does not match model feature, this inference run may fail.')
170
+ # No choice but to use the mixdb feature, since it differs from hparams['feature']
171
+ feature_mode = mixdb.feature
172
+
173
+ for mixid in p_mixids:
174
+ # frames x stride x feature_params
175
+ feature, _ = mixdb.mixture_ft(mixid)
176
+ if hparams['timesteps'] == 0:
177
+ # no timestep dimension, reshape will handle
178
+ timesteps = 0
179
+ else:
180
+ # fit frames into timestep dimension (TSE mode)
181
+ timesteps = feature.shape[0]
182
+
183
+ feature, _ = reshape_inputs(feature=feature,
184
+ batch_size=1,
185
+ timesteps=timesteps,
186
+ flatten=hparams['flatten'],
187
+ add1ch=hparams['add1ch'])
188
+ if model_is_fp16:
189
+ feature = np.float16(feature) # type: ignore
190
+ # Run inference; the ort session expects batch x timesteps x feat_params and outputs numpy BxTxFP or BxFP
191
+ predict = session.run(out_names, {in0name: feature})[0]
192
+ # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
193
+ output_fname = join(output_dir, mixdb.mixtures[mixid].name)
194
+ with h5py.File(output_fname, 'a') as f:
198
195
  if 'predict' in f:
199
196
  del f['predict']
200
- f.create_dataset('predict', data=predict[file_indices[idx]])
201
-
202
- logger.info(f'Saved results to {output_dir}')
203
-
204
-
205
- def pad_and_predict(feature: Feature,
206
- model_name: str,
207
- model_metadata: SonusAIMetaData,
208
- frames_per_batch: int,
209
- batch_size: int,
210
- timesteps: int,
211
- reset: bool) -> Predict:
212
- import onnxruntime as rt
213
- import numpy as np
214
-
215
- from sonusai.utils import reshape_inputs
216
- from sonusai.utils import reshape_outputs
217
-
218
- frames = feature.shape[0]
219
- padding = frames_per_batch - frames % frames_per_batch
220
- feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
221
- feature, _ = reshape_inputs(feature=feature,
222
- batch_size=batch_size,
223
- timesteps=timesteps,
224
- flatten=model_metadata.flattened,
225
- add1ch=model_metadata.channel)
226
- sequences = feature.shape[0] // model_metadata.input_shape[0]
227
- feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
228
-
229
- model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
230
- output_names = [n.name for n in model.get_outputs()]
231
- input_names = [n.name for n in model.get_inputs()]
232
-
233
- predict = []
234
- for sequence in range(sequences):
235
- predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
236
- if reset:
237
- model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
238
-
239
- predict_arr = np.vstack(predict)
240
- # Combine [sequences, batch_size, ...] into [frames, ...]
241
- predict_shape = predict_arr.shape
242
- predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
243
- predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
244
- predict_arr = predict_arr[:frames, :]
245
-
246
- return predict_arr
197
+ f.create_dataset('predict', data=predict)
198
+ if wav:
199
+ # Note: this only makes sense if the model is predicting audio, i.e., the timestep dimension exists
200
+ # predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
201
+ predict = np.transpose(predict, [1, 0, 2])
202
+ predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
203
+ owav_name = splitext(output_fname)[0] + '_predict.wav'
204
+ write_wav(owav_name, predict_audio)
247
205
 
248
206
 
249
207
  if __name__ == '__main__':
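For reference, the per-mixture loop above reduces to a single session.run call with one named input. A minimal standalone sketch using onnxruntime directly (the model path, shapes, and dtype handling here are assumptions, not values taken from this package):

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
    in0 = session.get_inputs()[0]
    out_names = [o.name for o in session.get_outputs()]

    # batch x timesteps x feat_params; cast to float16 if the model input expects it
    feature = np.zeros((1, 300, 64), dtype=np.float32)
    if 'float16' in in0.type:
        feature = feature.astype(np.float16)

    predict = session.run(out_names, {in0.name: feature})[0]   # numpy array, e.g. B x T x FP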
@@ -134,7 +134,7 @@ def get_mixids_from_target(mixdb: MixtureDatabase,
134
134
  """
135
135
  return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
136
136
  mixids=mixids,
137
- field='target_id',
137
+ field='target_ids',
138
138
  predicate=predicate)
139
139
 
140
140
 
@@ -0,0 +1,3 @@
1
+ from .textgrid import annotate_textgrid
2
+ from .textgrid import create_textgrid
3
+ from .types import TimeAlignedType
@@ -0,0 +1,116 @@
1
+ import os
2
+ import string
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ from .types import TimeAlignedType
7
+
8
+
9
+ def _get_duration(name: str) -> float:
10
+ import soundfile
11
+
12
+ from sonusai import SonusAIError
13
+
14
+ try:
15
+ return soundfile.info(name).duration
16
+ except Exception as e:
17
+ raise SonusAIError(f'Error reading {name}: {e}')
18
+
19
+
20
+ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
21
+ """Load time-aligned text data given a L2-ARCTIC audio file.
22
+
23
+ :param audio: Path to the L2-ARCTIC audio file.
24
+ :return: A TimeAlignedType object.
25
+ """
26
+ file = Path(audio).parent.parent / 'transcript' / (Path(audio).stem + '.txt')
27
+ if not os.path.exists(file):
28
+ return None
29
+
30
+ with open(file, mode='r', encoding='utf-8') as f:
31
+ line = f.read()
32
+
33
+ return TimeAlignedType(0,
34
+ _get_duration(str(audio)),
35
+ line.strip().lower().translate(str.maketrans('', '', string.punctuation)))
36
+
37
+
38
+ def load_words(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
39
+ """Load time-aligned word data given a L2-ARCTIC audio file.
40
+
41
+ :param audio: Path to the L2-ARCTIC audio file.
42
+ :return: A list of TimeAlignedType objects.
43
+ """
44
+ return _load_ta(audio, 'words')
45
+
46
+
47
+ def load_phonemes(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
48
+ """Load time-aligned phonemes data given a L2-ARCTIC audio file.
49
+
50
+ :param audio: Path to the L2-ARCTIC audio file.
51
+ :return: A list of TimeAlignedType objects.
52
+ """
53
+ return _load_ta(audio, 'phones')
54
+
55
+
56
+ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlignedType]]:
57
+ from praatio import textgrid
58
+
59
+ file = Path(audio).parent.parent / 'textgrid' / (Path(audio).stem + '.TextGrid')
60
+ if not os.path.exists(file):
61
+ return None
62
+
63
+ tg = textgrid.openTextgrid(str(file), includeEmptyIntervals=False)
64
+ if tier not in tg.tierNames:
65
+ return None
66
+
67
+ entries: list[TimeAlignedType] = []
68
+ for entry in tg.getTier(tier).entries:
69
+ entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
70
+
71
+ return entries
72
+
73
+
74
+ def load_annotations(audio: str | os.PathLike[str]) -> Optional[dict[str, list[TimeAlignedType]]]:
75
+ """Load time-aligned annotation data given a L2-ARCTIC audio file.
76
+
77
+ :param audio: Path to the L2-ARCTIC audio file.
78
+ :return: A dictionary mapping tier names to lists of TimeAlignedType objects.
79
+ """
80
+ from praatio import textgrid
81
+
82
+ file = Path(audio).parent.parent / 'annotation' / (Path(audio).stem + '.TextGrid')
83
+ if not os.path.exists(file):
84
+ return None
85
+
86
+ tg = textgrid.openTextgrid(str(file), includeEmptyIntervals=False)
87
+ result: dict[str, list[TimeAlignedType]] = {}
88
+ for tier in tg.tierNames:
89
+ entries: list[TimeAlignedType] = []
90
+ for entry in tg.getTier(tier).entries:
91
+ entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
92
+ result[tier] = entries
93
+
94
+ return result
95
+
96
+
97
+ def load_speakers(input_dir: Path) -> dict:
98
+ speakers = {}
99
+ with open(input_dir / 'readme-download.txt') as file:
100
+ processing = False
101
+ for line in file:
102
+ if not processing and line.startswith('|---|'):
103
+ processing = True
104
+ continue
105
+
106
+ if processing:
107
+ if line.startswith('|**Total**|'):
108
+ break
109
+ else:
110
+ fields = line.strip().split('|')
111
+ speaker_id = fields[1]
112
+ gender = fields[2]
113
+ dialect = fields[3]
114
+ speakers[speaker_id] = {'gender': gender, 'dialect': dialect}
115
+
116
+ return speakers
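A hypothetical usage sketch for the L2-ARCTIC loaders added above (the dataset path follows the directory layout the module expects but is an assumption, and the functions are assumed to be importable from this new module):

    from pathlib import Path

    audio = '/data/l2arctic/ABA/wav/arctic_a0001.wav'   # hypothetical L2-ARCTIC layout

    text = load_text(audio)          # whole-utterance transcript with punctuation stripped, or None
    words = load_words(audio)        # 'words' tier from the matching .TextGrid, or None
    phones = load_phonemes(audio)    # 'phones' tier, or None
    speakers = load_speakers(Path('/data/l2arctic'))    # parsed from readme-download.txt

    if words is not None:
        for w in words:
            print(f'{w.start:.2f}-{w.end:.2f}: {w.text}')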
@@ -0,0 +1,99 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ from .types import TimeAlignedType
6
+
7
+
8
+ def _get_num_samples(audio: str | os.PathLike[str]) -> int:
9
+ """Get number of samples from audio file using soundfile
10
+
11
+ :param audio: Audio file name
12
+ :return: Number of samples
13
+ """
14
+ import soundfile
15
+ from pydub import AudioSegment
16
+
17
+ if Path(audio).suffix == '.mp3':
18
+ return int(AudioSegment.from_mp3(audio).frame_count())
19
+
20
+ if Path(audio).suffix == '.m4a':
21
+ return int(AudioSegment.from_file(audio).frame_count())
22
+
23
+ return soundfile.info(audio).frames
24
+
25
+
26
+ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
27
+ """Load text data from a LibriSpeech transcription file given a LibriSpeech audio filename.
28
+
29
+ :param audio: Path to the LibriSpeech audio file.
30
+ :return: A TimeAlignedType object.
31
+ """
32
+ import string
33
+
34
+ from sonusai.mixture import get_sample_rate
35
+
36
+ path = Path(audio)
37
+ name = path.stem
38
+ transcript_filename = path.parent / f'{path.parent.parent.name}-{path.parent.name}.trans.txt'
39
+
40
+ if not os.path.exists(transcript_filename):
41
+ return None
42
+
43
+ with open(transcript_filename, mode='r', encoding='utf-8') as f:
44
+ for line in f.readlines():
45
+ fields = line.strip().split()
46
+ key = fields[0]
47
+ if key == name:
48
+ text = ' '.join(fields[1:]).lower().translate(str.maketrans('', '', string.punctuation))
49
+ return TimeAlignedType(0, _get_num_samples(audio) / get_sample_rate(str(audio)), text)
50
+
51
+ return None
52
+
53
+
54
+ def load_words(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
55
+ """Load time-aligned word data given a LibriSpeech audio file.
56
+
57
+ :param audio: Path to the LibriSpeech audio file.
58
+ :return: A list of TimeAlignedType objects.
59
+ """
60
+ return _load_ta(audio, 'words')
61
+
62
+
63
+ def load_phonemes(audio: str | os.PathLike[str]) -> Optional[list[TimeAlignedType]]:
64
+ """Load time-aligned phonemes data given a LibriSpeech audio file.
65
+
66
+ :param audio: Path to the LibriSpeech audio file.
67
+ :return: A list of TimeAlignedType objects.
68
+ """
69
+ return _load_ta(audio, 'phones')
70
+
71
+
72
+ def _load_ta(audio: str | os.PathLike[str], tier: str) -> Optional[list[TimeAlignedType]]:
73
+ from praatio import textgrid
74
+
75
+ file = Path(audio).with_suffix('.TextGrid')
76
+ if not os.path.exists(file):
77
+ return None
78
+
79
+ tg = textgrid.openTextgrid(str(file), includeEmptyIntervals=False)
80
+ if tier not in tg.tierNames:
81
+ return None
82
+
83
+ entries: list[TimeAlignedType] = []
84
+ for entry in tg.getTier(tier).entries:
85
+ entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
86
+
87
+ return entries
88
+
89
+
90
+ def load_speakers(input_dir: Path) -> dict:
91
+ speakers = {}
92
+ with open(input_dir / 'SPEAKERS.TXT') as file:
93
+ for line in file:
94
+ if not line.startswith(';'):
95
+ fields = line.strip().split('|')
96
+ speaker_id = fields[0].strip()
97
+ gender = fields[1].strip()
98
+ speakers[speaker_id] = {'gender': gender}
99
+ return speakers
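Illustrative only: the transcript filename convention that load_text() above relies on, shown with a hypothetical LibriSpeech path:

    from pathlib import Path

    audio = Path('/data/LibriSpeech/train-clean-100/19/198/19-198-0001.flac')
    transcript = audio.parent / f'{audio.parent.parent.name}-{audio.parent.name}.trans.txt'
    # -> .../19/198/19-198.trans.txt, where each line is '<utterance-id> <TRANSCRIPT>'
    print(transcript)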
@@ -0,0 +1,70 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from .types import TimeAlignedType
5
+
6
+
7
+ def load_text(audio: str | os.PathLike[str]) -> Optional[TimeAlignedType]:
8
+ """Load time-aligned text data given a McGill-Speech audio file.
9
+
10
+ :param audio: Path to the McGill-Speech audio file.
11
+ :return: A TimeAlignedType object.
12
+ """
13
+ import string
14
+ import struct
15
+
16
+ from sonusai.mixture import get_sample_rate
17
+
18
+ if not os.path.exists(audio):
19
+ return None
20
+
21
+ sample_rate = get_sample_rate(str(audio))
22
+
23
+ with open(audio, mode='rb') as f:
24
+ content = f.read()
25
+
26
+ riff_id, file_size, wave_id = struct.unpack('<4si4s', content[:12])
27
+ if riff_id.decode('utf-8') != 'RIFF':
28
+ return None
29
+
30
+ if wave_id.decode('utf-8') != 'WAVE':
31
+ return None
32
+
33
+ fmt_id, fmt_size = struct.unpack('<4si', content[12:20])
34
+
35
+ if fmt_id.decode('utf-8') != 'fmt ':
36
+ return None
37
+
38
+ if fmt_size != 16:
39
+ return None
40
+
41
+ (_wave_format_tag,
42
+ channels,
43
+ _samples_per_sec,
44
+ _avg_bytes_per_sec,
45
+ _block_align,
46
+ bits_per_sample) = struct.unpack('<hhiihh', content[20:36])
47
+
48
+ i = 36
49
+ samples = None
50
+ text = None
51
+ while i < file_size:
52
+ chunk_id = struct.unpack('<4s', content[i:i + 4])[0].decode('utf-8')
53
+ chunk_size = struct.unpack('<i', content[i + 4:i + 8])[0]
54
+
55
+ if chunk_id == 'data':
56
+ samples = chunk_size / channels / (bits_per_sample / 8)
57
+ break
58
+
59
+ if chunk_id == 'afsp':
60
+ chunks = struct.unpack(f'<{chunk_size}s', content[i + 8:i + 8 + chunk_size])[0]
61
+ chunks = chunks.decode('utf-8').split('\x00')
62
+ for chunk in chunks:
63
+ if chunk.startswith('text: "'):
64
+ text = chunk[7:-1].lower().translate(str.maketrans('', '', string.punctuation))
65
+ i += 8 + chunk_size + chunk_size % 2
66
+
67
+ if text and samples:
68
+ return TimeAlignedType(start=0, end=samples / sample_rate, text=text)
69
+
70
+ return None
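A hypothetical usage of the McGill-Speech loader above (the path is an assumption, and load_text is assumed to be importable from this new module):

    t = load_text('/data/mcgill/speaker01/sentence001.wav')
    if t is not None:
        print(f'{t.end - t.start:.2f}s: {t.text}')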