phoonnx-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. phoonnx/__init__.py +0 -0
  2. phoonnx/config.py +490 -0
  3. phoonnx/locale/ca/phonetic_spellings.txt +2 -0
  4. phoonnx/locale/en/phonetic_spellings.txt +1 -0
  5. phoonnx/locale/gl/phonetic_spellings.txt +2 -0
  6. phoonnx/locale/pt/phonetic_spellings.txt +2 -0
  7. phoonnx/phoneme_ids.py +453 -0
  8. phoonnx/phonemizers/__init__.py +45 -0
  9. phoonnx/phonemizers/ar.py +42 -0
  10. phoonnx/phonemizers/base.py +216 -0
  11. phoonnx/phonemizers/en.py +250 -0
  12. phoonnx/phonemizers/fa.py +46 -0
  13. phoonnx/phonemizers/gl.py +142 -0
  14. phoonnx/phonemizers/he.py +67 -0
  15. phoonnx/phonemizers/ja.py +119 -0
  16. phoonnx/phonemizers/ko.py +97 -0
  17. phoonnx/phonemizers/mul.py +606 -0
  18. phoonnx/phonemizers/vi.py +44 -0
  19. phoonnx/phonemizers/zh.py +308 -0
  20. phoonnx/thirdparty/__init__.py +0 -0
  21. phoonnx/thirdparty/arpa2ipa.py +249 -0
  22. phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
  23. phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
  24. phoonnx/thirdparty/hangul2ipa.py +783 -0
  25. phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
  26. phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
  27. phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
  28. phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
  29. phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
  30. phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
  31. phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
  32. phoonnx/thirdparty/ko_tables/yale.csv +22 -0
  33. phoonnx/thirdparty/kog2p/__init__.py +385 -0
  34. phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
  35. phoonnx/thirdparty/mantoq/__init__.py +67 -0
  36. phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
  37. phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
  38. phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
  39. phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
  40. phoonnx/thirdparty/mantoq/num2words.py +37 -0
  41. phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
  42. phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
  43. phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
  44. phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
  45. phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
  46. phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
  47. phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
  48. phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
  49. phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
  50. phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
  51. phoonnx/thirdparty/tashkeel/LICENSE +22 -0
  52. phoonnx/thirdparty/tashkeel/SOURCE +1 -0
  53. phoonnx/thirdparty/tashkeel/__init__.py +212 -0
  54. phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
  55. phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
  56. phoonnx/thirdparty/tashkeel/model.onnx +0 -0
  57. phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
  58. phoonnx/thirdparty/zh_num.py +238 -0
  59. phoonnx/util.py +705 -0
  60. phoonnx/version.py +6 -0
  61. phoonnx/voice.py +521 -0
  62. phoonnx-0.0.0.dist-info/METADATA +255 -0
  63. phoonnx-0.0.0.dist-info/RECORD +86 -0
  64. phoonnx-0.0.0.dist-info/WHEEL +5 -0
  65. phoonnx-0.0.0.dist-info/top_level.txt +2 -0
  66. phoonnx_train/__main__.py +151 -0
  67. phoonnx_train/export_onnx.py +109 -0
  68. phoonnx_train/norm_audio/__init__.py +92 -0
  69. phoonnx_train/norm_audio/trim.py +54 -0
  70. phoonnx_train/norm_audio/vad.py +54 -0
  71. phoonnx_train/preprocess.py +420 -0
  72. phoonnx_train/vits/__init__.py +0 -0
  73. phoonnx_train/vits/attentions.py +427 -0
  74. phoonnx_train/vits/commons.py +147 -0
  75. phoonnx_train/vits/config.py +330 -0
  76. phoonnx_train/vits/dataset.py +214 -0
  77. phoonnx_train/vits/lightning.py +352 -0
  78. phoonnx_train/vits/losses.py +58 -0
  79. phoonnx_train/vits/mel_processing.py +139 -0
  80. phoonnx_train/vits/models.py +732 -0
  81. phoonnx_train/vits/modules.py +527 -0
  82. phoonnx_train/vits/monotonic_align/__init__.py +20 -0
  83. phoonnx_train/vits/monotonic_align/setup.py +13 -0
  84. phoonnx_train/vits/transforms.py +212 -0
  85. phoonnx_train/vits/utils.py +16 -0
  86. phoonnx_train/vits/wavfile.py +860 -0
--- /dev/null
+++ b/phoonnx_train/norm_audio/__init__.py
@@ -0,0 +1,92 @@
+ from hashlib import sha256
+ from pathlib import Path
+ from typing import Optional, Tuple, Union
+
+ import librosa
+ import torch
+
+ from phoonnx_train.vits.mel_processing import spectrogram_torch
+
+ from .trim import trim_silence
+ from .vad import SileroVoiceActivityDetector
+
+ _DIR = Path(__file__).parent
+
+
+ def make_silence_detector() -> SileroVoiceActivityDetector:
+     silence_model = _DIR / "models" / "silero_vad.onnx"
+     return SileroVoiceActivityDetector(silence_model)
+
+
+ def cache_norm_audio(
+     audio_path: Union[str, Path],
+     cache_dir: Union[str, Path],
+     detector: SileroVoiceActivityDetector,
+     sample_rate: int,
+     silence_threshold: float = 0.2,
+     silence_samples_per_chunk: int = 480,
+     silence_keep_chunks_before: int = 2,
+     silence_keep_chunks_after: int = 2,
+     filter_length: int = 1024,
+     window_length: int = 1024,
+     hop_length: int = 256,
+     ignore_cache: bool = False,
+ ) -> Tuple[Path, Path]:
+     audio_path = Path(audio_path).absolute()
+     cache_dir = Path(cache_dir)
+
+     # Cache id is the SHA256 of the full audio path
+     audio_cache_id = sha256(str(audio_path).encode()).hexdigest()
+
+     audio_norm_path = cache_dir / f"{audio_cache_id}.pt"
+     audio_spec_path = cache_dir / f"{audio_cache_id}.spec.pt"
+
+     # Normalize audio
+     audio_norm_tensor: Optional[torch.FloatTensor] = None
+     if ignore_cache or (not audio_norm_path.exists()):
+         # Trim silence first.
+         #
+         # The VAD model works on 16khz, so we determine the portion of audio
+         # to keep and then just load that with librosa.
+         vad_sample_rate = 16000
+         audio_16khz, _sr = librosa.load(path=audio_path, sr=vad_sample_rate)
+
+         offset_sec, duration_sec = trim_silence(
+             audio_16khz,
+             detector,
+             threshold=silence_threshold,
+             samples_per_chunk=silence_samples_per_chunk,
+             sample_rate=vad_sample_rate,
+             keep_chunks_before=silence_keep_chunks_before,
+             keep_chunks_after=silence_keep_chunks_after,
+         )
+
+         # NOTE: audio is already in [-1, 1] coming from librosa
+         audio_norm_array, _sr = librosa.load(
+             path=audio_path,
+             sr=sample_rate,
+             offset=offset_sec,
+             duration=duration_sec,
+         )
+
+         # Save to cache directory
+         audio_norm_tensor = torch.FloatTensor(audio_norm_array).unsqueeze(0)
+         torch.save(audio_norm_tensor, audio_norm_path)
+
+     # Compute spectrogram
+     if ignore_cache or (not audio_spec_path.exists()):
+         if audio_norm_tensor is None:
+             # Load pre-cached normalized audio
+             audio_norm_tensor = torch.load(audio_norm_path)
+
+         audio_spec_tensor = spectrogram_torch(
+             y=audio_norm_tensor,
+             n_fft=filter_length,
+             sampling_rate=sample_rate,
+             hop_size=hop_length,
+             win_size=window_length,
+             center=False,
+         ).squeeze(0)
+         torch.save(audio_spec_tensor, audio_spec_path)
+
+     return audio_norm_path, audio_spec_path
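Taken together, make_silence_detector() and cache_norm_audio() are the audio half of preprocessing: trim leading/trailing silence with Silero VAD, resample, then cache the normalized waveform and its spectrogram as tensors keyed by the SHA256 of the audio path. A minimal usage sketch, not part of the wheel itself (the wav path, cache directory, and sample rate are placeholders; it also assumes a models/silero_vad.onnx exists next to this module, which is not among the files listed above):

    from pathlib import Path

    from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector

    wav = Path("wavs/0001.wav")   # placeholder input file
    cache = Path("cache/22050")   # placeholder cache directory
    cache.mkdir(parents=True, exist_ok=True)

    detector = make_silence_detector()
    norm_path, spec_path = cache_norm_audio(
        audio_path=wav,
        cache_dir=cache,
        detector=detector,
        sample_rate=22050,
    )
    # norm_path -> cache/22050/<sha256>.pt      (normalized audio, shape [1, T])
    # spec_path -> cache/22050/<sha256>.spec.pt (linear spectrogram)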
--- /dev/null
+++ b/phoonnx_train/norm_audio/trim.py
@@ -0,0 +1,54 @@
+ from typing import Optional, Tuple
+
+ import numpy as np
+
+ from .vad import SileroVoiceActivityDetector
+
+
+ def trim_silence(
+     audio_array: np.ndarray,
+     detector: SileroVoiceActivityDetector,
+     threshold: float = 0.2,
+     samples_per_chunk=480,
+     sample_rate=16000,
+     keep_chunks_before: int = 2,
+     keep_chunks_after: int = 2,
+ ) -> Tuple[float, Optional[float]]:
+     """Returns the offset/duration of trimmed audio in seconds"""
+     offset_sec: float = 0.0
+     duration_sec: Optional[float] = None
+     first_chunk: Optional[int] = None
+     last_chunk: Optional[int] = None
+     seconds_per_chunk: float = samples_per_chunk / sample_rate
+
+     chunk = audio_array[:samples_per_chunk]
+     audio_array = audio_array[samples_per_chunk:]
+     chunk_idx: int = 0
+
+     # Determine main block of speech
+     while len(audio_array) > 0:
+         prob = detector(chunk, sample_rate=sample_rate)
+         is_speech = prob >= threshold
+
+         if is_speech:
+             if first_chunk is None:
+                 # First speech
+                 first_chunk = chunk_idx
+             else:
+                 # Last speech so far
+                 last_chunk = chunk_idx
+
+         chunk = audio_array[:samples_per_chunk]
+         audio_array = audio_array[samples_per_chunk:]
+         chunk_idx += 1
+
+     if (first_chunk is not None) and (last_chunk is not None):
+         first_chunk = max(0, first_chunk - keep_chunks_before)
+         last_chunk = min(chunk_idx, last_chunk + keep_chunks_after)
+
+         # Compute offset/duration
+         offset_sec = first_chunk * seconds_per_chunk
+         last_sec = (last_chunk + 1) * seconds_per_chunk
+         duration_sec = last_sec - offset_sec
+
+     return offset_sec, duration_sec
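trim_silence() returns (offset_sec, duration_sec); note that duration_sec stays None unless at least two chunks scored as speech, since the first hit only sets first_chunk and a later hit must set last_chunk. A sketch of the intended call pattern, mirroring how cache_norm_audio() above uses it (the wav path and target rate are placeholders):

    import librosa

    from phoonnx_train.norm_audio import make_silence_detector
    from phoonnx_train.norm_audio.trim import trim_silence

    detector = make_silence_detector()

    # The VAD runs at 16 kHz, so load a 16 kHz copy just for trimming
    audio_16khz, _sr = librosa.load("speech.wav", sr=16000)
    offset_sec, duration_sec = trim_silence(audio_16khz, detector)

    # Reload only the speech region at the training sample rate
    # (duration=None makes librosa load to the end of the file)
    speech, _sr = librosa.load(
        "speech.wav", sr=22050, offset=offset_sec, duration=duration_sec
    )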
--- /dev/null
+++ b/phoonnx_train/norm_audio/vad.py
@@ -0,0 +1,54 @@
+ import typing
+ from pathlib import Path
+
+ import numpy as np
+ import onnxruntime
+
+
+ class SileroVoiceActivityDetector:
+     """Detects speech/silence using Silero VAD.
+
+     https://github.com/snakers4/silero-vad
+     """
+
+     def __init__(self, onnx_path: typing.Union[str, Path]):
+         onnx_path = str(onnx_path)
+
+         self.session = onnxruntime.InferenceSession(onnx_path)
+         self.session.intra_op_num_threads = 1
+         self.session.inter_op_num_threads = 1
+
+         self._h = np.zeros((2, 1, 64)).astype("float32")
+         self._c = np.zeros((2, 1, 64)).astype("float32")
+
+     def __call__(self, audio_array: np.ndarray, sample_rate: int = 16000):
+         """Return probability of speech in audio [0-1].
+
+         Audio must be 16Khz 16-bit mono PCM.
+         """
+         if len(audio_array.shape) == 1:
+             # Add batch dimension
+             audio_array = np.expand_dims(audio_array, 0)
+
+         if len(audio_array.shape) > 2:
+             raise ValueError(
+                 f"Too many dimensions for input audio chunk {audio_array.shape}"
+             )
+
+         if audio_array.shape[0] > 1:
+             raise ValueError("Onnx model does not support batching")
+
+         if sample_rate != 16000:
+             raise ValueError("Only 16Khz audio is supported")
+
+         ort_inputs = {
+             "input": audio_array.astype(np.float32),
+             "h0": self._h,
+             "c0": self._c,
+         }
+         ort_outs = self.session.run(None, ort_inputs)
+         out, self._h, self._c = ort_outs
+
+         out = out.squeeze(2)[:, 1]  # make output type match JIT analog
+
+         return out
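The detector carries its LSTM state (_h/_c) across calls, so it is meant to be fed consecutive chunks of a single recording, exactly as trim_silence() does. A standalone sketch (the model path is a placeholder; in the package it is resolved by make_silence_detector()):

    import numpy as np

    from phoonnx_train.norm_audio.vad import SileroVoiceActivityDetector

    vad = SileroVoiceActivityDetector("silero_vad.onnx")  # placeholder path

    rng = np.random.default_rng(0)
    audio = rng.uniform(-0.1, 0.1, 16000).astype(np.float32)  # 1 s at 16 kHz

    # 480-sample chunks (30 ms), matching the default in trim.py
    for start in range(0, len(audio), 480):
        chunk = audio[start:start + 480]
        prob = vad(chunk, sample_rate=16000)  # array of shape (1,)
        print(f"{start / 16000:5.3f}s speech={prob.item():.3f}")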
--- /dev/null
+++ b/phoonnx_train/preprocess.py
@@ -0,0 +1,420 @@
+ #!/usr/bin/env python3
+ import argparse
+ import csv
+ import dataclasses
+ import itertools
+ import json
+ import logging
+ import os
+ from collections import Counter
+ from dataclasses import dataclass
+ from multiprocessing import JoinableQueue, Process, Queue
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union
+
+ from phoonnx.util import normalize
+ from phoonnx.config import PhonemeType, get_phonemizer, Alphabet
+ from phoonnx.phonemizers import Phonemizer
+ from phoonnx.phoneme_ids import (phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
+                                  DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN)
+ from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
+ from tqdm import tqdm
+
+ _VERSION = "0.0.0"
+ _LOGGER = logging.getLogger("preprocess")
+
+ # Base phoneme map
+ DEFAULT_SPECIAL_PHONEME_ID_MAP: Dict[str, int] = {
+     DEFAULT_PAD_TOKEN: 0,
+     DEFAULT_BOS_TOKEN: 1,
+     DEFAULT_EOS_TOKEN: 2,
+     DEFAULT_BLANK_WORD_TOKEN: 3,
+ }
+
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class Utterance:
+     """Represents a single utterance in the dataset."""
+     text: str
+     audio_path: Path
+     speaker: Optional[str] = None
+     speaker_id: Optional[int] = None
+     phonemes: Optional[List[str]] = None
+     phoneme_ids: Optional[List[int]] = None
+     audio_norm_path: Optional[Path] = None
+     audio_spec_path: Optional[Path] = None
+
+     def asdict(self) -> Dict[str, Any]:
+         """Custom asdict to handle Path objects."""
+         data = dataclasses.asdict(self)
+         for key, value in data.items():
+             if isinstance(value, Path):
+                 data[key] = str(value)
+         return data
+
+
+ class PathEncoder(json.JSONEncoder):
+     """JSON encoder for Path objects."""
+
+     def default(self, o):
+         if isinstance(o, Path):
+             return str(o)
+         return super().default(o)
+
+
+ def get_text_casing(casing: str):
+     """Returns a function to apply text casing based on a string."""
+     if casing == "lower":
+         return str.lower
+     if casing == "upper":
+         return str.upper
+     if casing == "casefold":
+         return str.casefold
+     return lambda s: s
+
+
+ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
+     """
+     Generator for LJSpeech-style dataset.
+     Loads metadata and resolves audio file paths.
+     """
+     dataset_dir = args.input_dir
+     metadata_path = dataset_dir / "metadata.csv"
+     if not metadata_path.exists():
+         _LOGGER.error(f"Missing metadata file: {metadata_path}")
+         return
+
+     wav_dirs = [dataset_dir / "wav", dataset_dir / "wavs"]
+
+     with open(metadata_path, "r", encoding="utf-8") as csv_file:
+         reader = csv.reader(csv_file, delimiter="|")
+         for row in reader:
+             if len(row) < 2:
+                 _LOGGER.warning(f"Skipping malformed row: {row}")
+                 continue
+
+             filename: str = row[0]
+             text: str = row[-1]
+             speaker: Optional[str] = None
+
+             if not args.single_speaker and len(row) > 2:
+                 speaker = row[1]
+             else:
+                 speaker = None
+
+             wav_path = None
+             for wav_dir in wav_dirs:
+                 potential_paths = [wav_dir / filename, wav_dir / f"{filename}.wav"]
+                 for path in potential_paths:
+                     if path.exists():
+                         wav_path = path
+                         break
+                 if wav_path:
+                     break
+
+             if not args.skip_audio and not wav_path:
+                 _LOGGER.warning("Missing audio file for filename: %s", filename)
+                 continue
+
+             if not args.skip_audio and wav_path and wav_path.stat().st_size == 0:
+                 _LOGGER.warning("Empty audio file: %s", wav_path)
+                 continue
+
+             yield Utterance(
+                 text=text,
+                 audio_path=wav_path,
+                 speaker=speaker,
+                 speaker_id=args.speaker_id,
+             )
+
+
+ def phonemize_worker(
+     args: argparse.Namespace,
+     task_queue: JoinableQueue,
+     result_queue: Queue,
+     phonemizer: Phonemizer,
+ ):
+     """
+     Worker process for phonemization and audio processing.
+     Returns the utterance and the unique phonemes found in its batch.
+     """
+     try:
+         casing = get_text_casing(args.text_casing)
+         silence_detector = make_silence_detector()
+
+         while True:
+             # Get a batch of utterances to process
+             utterance_batch: Union[List[Utterance], None] = task_queue.get()
+             if utterance_batch is None:
+                 # Signal to exit
+                 task_queue.task_done()
+                 break
+
+             for utt in utterance_batch:
+                 try:
+                     # Phonemize the text
+                     norm_utt = casing(normalize(utt.text, args.language))
+                     utt.phonemes = phonemizer.phonemize_to_list(norm_utt, args.language)
+
+                     # Process audio if not skipping
+                     if not args.skip_audio:
+                         utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
+                             utt.audio_path,
+                             args.cache_dir,
+                             silence_detector,
+                             args.sample_rate,
+                         )
+
+                     # Put the processed utterance and its phonemes into the result queue
+                     result_queue.put((utt, set(utt.phonemes)))
+                 except Exception:
+                     _LOGGER.exception("Failed to process utterance: %s", utt.audio_path)
+                     result_queue.put((None, set()))
+
+             task_queue.task_done()
+
+     except Exception:
+         _LOGGER.exception("Worker process failed")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="Preprocess a TTS dataset for training a VITS-style model."
+     )
+     parser.add_argument(
+         "--input-dir", required=True, help="Directory with audio dataset"
+     )
+     parser.add_argument(
+         "--output-dir",
+         required=True,
+         help="Directory to write output files for training",
+     )
+     parser.add_argument("--language", required=True, help="eSpeak-ng voice")
+     parser.add_argument(
+         "--sample-rate",
+         type=int,
+         required=True,
+         help="Target sample rate for voice (hertz)",
+     )
+     parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
+     parser.add_argument("--max-workers", type=int)
+     parser.add_argument(
+         "--single-speaker", action="store_true", help="Force single speaker dataset"
+     )
+     parser.add_argument(
+         "--speaker-id", type=int, help="Add speaker id to single speaker dataset"
+     )
+     parser.add_argument(
+         "--phoneme-type",
+         choices=list(PhonemeType),
+         default=PhonemeType.ESPEAK,
+         help="Type of phonemes to use (default: espeak)",
+     )
+     parser.add_argument(
+         "--alphabet",
+         choices=list(Alphabet),
+         default=Alphabet.IPA,
+         help="Phoneme alphabet to use (default: ipa)",
+     )
+     parser.add_argument(
+         "--phonemizer-model",
+         default="",
+         help="phonemizer model, if applicable",
+     )
+     parser.add_argument(
+         "--text-casing",
+         choices=("ignore", "lower", "upper", "casefold"),
+         default="ignore",
+         help="Casing applied to utterance text",
+     )
+     parser.add_argument(
+         "--dataset-name",
+         help="Name of dataset to put in config (default: name of <output_dir>/../)",
+     )
+     parser.add_argument(
+         "--audio-quality",
+         help="Audio quality to put in config (default: name of <output_dir>)",
+     )
+     parser.add_argument(
+         "--skip-audio", action="store_true", help="Don't preprocess audio"
+     )
+     parser.add_argument(
+         "--debug", action="store_true", help="Print DEBUG messages to the console"
+     )
+     args = parser.parse_args()
+
+     # Setup
+     level = logging.DEBUG if args.debug else logging.INFO
+     logging.basicConfig(level=level)
+     logging.getLogger().setLevel(level)
+     logging.getLogger("numba").setLevel(logging.WARNING)
+
+     if args.single_speaker and (args.speaker_id is not None):
+         _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
+         return
+
+     args.input_dir = Path(args.input_dir)
+     args.output_dir = Path(args.output_dir)
+     args.output_dir.mkdir(parents=True, exist_ok=True)
+     args.cache_dir = (
+         Path(args.cache_dir)
+         if args.cache_dir
+         else args.output_dir / "cache" / str(args.sample_rate)
+     )
+     args.cache_dir.mkdir(parents=True, exist_ok=True)
+     args.phoneme_type = PhonemeType(args.phoneme_type)
+
+     # Load all utterances from the dataset
+     _LOGGER.info("Loading utterances from dataset...")
+     utterances = list(ljspeech_dataset(args))
+     if not utterances:
+         _LOGGER.error("No valid utterances found in dataset.")
+         return
+
+     num_utterances = len(utterances)
+     _LOGGER.info("Found %d utterances.", num_utterances)
+
+     # Count speakers
+     speaker_counts: Counter[str] = Counter(u.speaker for u in utterances if u.speaker)
+     is_multispeaker = len(speaker_counts) > 1
+     speaker_ids: Dict[str, int] = {}
+     if is_multispeaker:
+         _LOGGER.info("%s speakers detected", len(speaker_counts))
+         # Assign speaker ids by most number of utterances first
+         for speaker_id, (speaker, _) in enumerate(speaker_counts.most_common()):
+             speaker_ids[speaker] = speaker_id
+     else:
+         _LOGGER.info("Single speaker dataset")
+
+     # --- Single Pass: Process audio/phonemes and collect results ---
+     # Set up multiprocessing
+     args.max_workers = args.max_workers if args.max_workers is not None and args.max_workers > 0 else os.cpu_count()
+     _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)
+
+     # Initialize the phonemizer only once in the main process
+     phonemizer = get_phonemizer(args.phoneme_type, args.alphabet, args.phonemizer_model)
+
+     batch_size = max(1, int(num_utterances / (args.max_workers * 2)))
+
+     task_queue: "Queue[Optional[List[Utterance]]]" = JoinableQueue()
+     # The result queue will hold tuples of (Utterance, set(phonemes))
+     result_queue: "Queue[Optional[Tuple[Utterance, Set[str]]]]" = Queue()
+
+     # Start workers
+     processes = [
+         Process(
+             target=phonemize_worker,
+             args=(args, task_queue, result_queue, phonemizer)
+         )
+         for _ in range(args.max_workers)
+     ]
+
+     for proc in processes:
+         proc.start()
+
+     # Populate the task queue with batches
+     task_count = 0
+     for utt_batch in batched(utterances, batch_size):
+         task_queue.put(utt_batch)
+         task_count += len(utt_batch)
+
+     # Signal workers to stop
+     for _ in range(args.max_workers):
+         task_queue.put(None)
+
+     # Collect results from the queue with a progress bar
+     processed_utterances: List[Utterance] = []
+     all_phonemes: Set[str] = set()
+     for _ in tqdm(range(task_count), desc="Processing utterances"):
+         utt, unique_phonemes = result_queue.get()
+         if utt is not None:
+             processed_utterances.append(utt)
+             all_phonemes.update(unique_phonemes)
+
+     # Wait for workers to finish
+     task_queue.join()
+     for proc in processes:
+         proc.join()
+
+     # --- Build the final phoneme map from the collected phonemes ---
+     _LOGGER.info("Building a complete phoneme map from collected phonemes...")
+
+     final_phoneme_id_map = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
+     if phonemizer.alphabet == Alphabet.IPA:
+         all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())
+
+     # Filter out special tokens that are already in the map
+     existing_keys = set(final_phoneme_id_map.keys())
+     new_phonemes = sorted([p for p in all_phonemes if p not in existing_keys])
+
+     current_id = len(final_phoneme_id_map)
+     for pho in new_phonemes:
+         final_phoneme_id_map[pho] = current_id
+         current_id += 1
+
+     _LOGGER.info("Final phoneme map contains %d symbols.", len(final_phoneme_id_map))
+
+     # --- Write the final config.json ---
+     _LOGGER.info("Writing dataset config...")
+     audio_quality = args.audio_quality or args.output_dir.name
+     dataset_name = args.dataset_name or args.output_dir.parent.name
+
+     config = {
+         "dataset": dataset_name,
+         "audio": {
+             "sample_rate": args.sample_rate,
+             "quality": audio_quality,
+         },
+         "lang_code": args.language,
+         "inference": {"noise_scale": 0.667, "length_scale": 1, "noise_w": 0.8},
+         "alphabet": phonemizer.alphabet.value,
+         "phoneme_type": args.phoneme_type.value,
+         "phonemizer_model": args.phonemizer_model,
+         "phoneme_id_map": final_phoneme_id_map,
+         "num_symbols": len(final_phoneme_id_map),
+         "num_speakers": len(speaker_counts) if is_multispeaker else 1,
+         "speaker_id_map": speaker_ids,
+         "phoonnx_version": _VERSION,
+     }
+
+     with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
+         json.dump(config, config_file, ensure_ascii=False, indent=2)
+
+     # --- Apply final phoneme IDs and write dataset.jsonl ---
+     _LOGGER.info("Writing dataset.jsonl...")
+     with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
+         for utt in processed_utterances:
+             if utt.speaker is not None:
+                 utt.speaker_id = speaker_ids[utt.speaker]
+
+             # Apply the final phoneme ID map to each utterance
+             if utt.phonemes:
+                 utt.phoneme_ids = phonemes_to_ids(utt.phonemes, id_map=final_phoneme_id_map)
+
+             json.dump(
+                 utt.asdict(),
+                 dataset_file,
+                 ensure_ascii=False,
+                 cls=PathEncoder,
+             )
+             print("", file=dataset_file)
+
+     _LOGGER.info("Preprocessing complete.")
+
+
+ # -----------------------------------------------------------------------------
+
+ def batched(iterable, n):
+     "Batch data into lists of length n. The last batch may be shorter."
+     if n < 1:
+         raise ValueError("n must be at least one")
+     it = iter(iterable)
+     batch = list(itertools.islice(it, n))
+     while batch:
+         yield batch
+         batch = list(itertools.islice(it, n))
+
+
+ if __name__ == "__main__":
+     main()
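A hypothetical end-to-end run of this script (all paths and values below are placeholders), followed by a peek at the two artifacts it writes:

    # Hypothetical invocation, e.g.:
    #
    #   python -m phoonnx_train.preprocess \
    #       --input-dir /data/my_dataset \
    #       --output-dir /data/my_dataset/train/medium \
    #       --language en \
    #       --sample-rate 22050
    #
    # The output directory then contains config.json and dataset.jsonl:
    import json
    from pathlib import Path

    out_dir = Path("/data/my_dataset/train/medium")  # placeholder

    config = json.loads((out_dir / "config.json").read_text(encoding="utf-8"))
    print(config["num_symbols"], config["num_speakers"], config["phoneme_type"])

    with open(out_dir / "dataset.jsonl", encoding="utf-8") as f:
        utt = json.loads(next(f))  # one JSON utterance per line
    print(utt["text"], utt["phoneme_ids"][:10])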