pocket-tts 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Pocket TTS - Streaming</title>
+     <script src="https://cdn.tailwindcss.com"></script>
+     <style>
+         body {
+             font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+         }
+         .spinner {
+             animation: spin 1s linear infinite;
+         }
+         @keyframes spin {
+             from { transform: rotate(0deg); }
+             to { transform: rotate(360deg); }
+         }
+     </style>
+ </head>
+ <body class="bg-gray-50 min-h-screen">
+     <div class="max-w-xl mx-auto p-4 space-y-4">
+         <h1 class="text-xl font-bold text-center">Pocket TTS</h1>
+
+         <div class="bg-white p-4 rounded-lg shadow">
+             <textarea
+                 id="text-input"
+                 class="w-full border rounded p-2"
+                 placeholder="Enter your text here..."
+                 rows="4"
+             >Hello world. I am Kyutai's Pocket TTS. I'm fast enough to run on small CPUs. I hope you'll like me.</textarea>
+         </div>
+
+         <div class="bg-white p-4 rounded-lg shadow">
+             <label for="voice-url-input" class="block text-sm font-medium text-gray-700 mb-2">
+                 Optional voice URL (leave empty to use the default voice):
+             </label>
+             <input
+                 type="text"
+                 id="voice-url-input"
+                 class="w-full border rounded p-2"
+                 placeholder="hf://kyutai/tts-voices/alba-mackenna/casual.wav"
+                 value="alba"
+             />
+             <p class="text-xs text-gray-500 mt-1">
+                 Supports http://, https://, or hf:// URLs.<br>
+                 You can also use predefined voices:<br>
+                 "alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma".
+             </p>
+         </div>
+
+         <div class="bg-white p-4 rounded-lg shadow">
+             <label for="voice-wav-input" class="block text-sm font-medium text-gray-700 mb-2">
+                 Or upload an audio file for voice cloning:
+             </label>
+             <input
+                 type="file"
+                 id="voice-wav-input"
+                 class="w-full border rounded p-2"
+                 accept=".wav,.mp3,.flac,.ogg,.m4a,audio/*"
+             />
+             <p class="text-xs text-gray-500 mt-1">
+                 Upload an audio file (WAV, MP3, FLAC, etc.) to use as the voice reference. It takes precedence over the voice URL.
+             </p>
+         </div>
+
+
+         <button
+             id="generate-btn"
+             class="w-full px-4 py-2 rounded text-white bg-blue-600 hover:bg-blue-700 disabled:bg-gray-400 disabled:cursor-not-allowed"
+         >
+             <span id="generate-text">Generate audio</span>
+         </button>
+
+         <div id="status" class="hidden">
+             <div class="flex items-center space-x-2">
+                 <div id="status-spinner" class="spinner w-4 h-4 border-2 border-blue-600 border-t-transparent rounded-full"></div>
+                 <span id="status-text"></span>
+             </div>
+         </div>
+
+         <div id="audio-section" class="hidden">
+             <div class="bg-white p-4 rounded-lg shadow">
+                 <h3 class="text-lg font-semibold mb-2">Generated Audio</h3>
+                 <audio id="audio" controls class="w-full"></audio>
+                 <div class="mt-2">
+                     <a id="download-link" class="text-blue-600 text-sm hover:underline">Download</a>
+                 </div>
+             </div>
+         </div>
+     </div>
+
+     <script>
+         class StreamingWavPlayer {
+             constructor() {
+                 this.audioContext = new (window.AudioContext || window.webkitAudioContext)();
+                 this.sampleRate = 0;
+                 this.numChannels = 0;
+                 this.headerParsed = false;
+                 this.headerBuffer = new Uint8Array(44);
+                 this.headerBytesReceived = 0;
+                 this.nextStartTime = 0;
+                 this.isPlaying = false;
+                 this.firstAudioPlayed = false;
+                 this.minBufferSize = 16384;
+                 this.pcmData = new Uint8Array(0);
+             }
+
+             // Assumes a canonical 44-byte WAV header with the fmt chunk at a fixed offset.
+             parseWavHeader(header) {
+                 const view = new DataView(header.buffer);
+
+                 const riff = String.fromCharCode.apply(null, Array.from(header.slice(0, 4)));
+                 const wave = String.fromCharCode.apply(null, Array.from(header.slice(8, 12)));
+
+                 if (riff !== 'RIFF' || wave !== 'WAVE') {
+                     throw new Error('Invalid WAV file');
+                 }
+
+                 this.numChannels = view.getUint16(22, true);
+                 this.sampleRate = view.getUint32(24, true);
+                 const bitsPerSample = view.getUint16(34, true);
+
+                 console.log(`WAV Format: ${this.sampleRate}Hz, ${this.numChannels} channels, ${bitsPerSample} bits`);
+
+                 this.headerParsed = true;
+             }
+
+             appendPcmData(newData) {
+                 const newBuffer = new Uint8Array(this.pcmData.length + newData.length);
+                 newBuffer.set(this.pcmData);
+                 newBuffer.set(newData, this.pcmData.length);
+                 this.pcmData = newBuffer;
+             }
+
+             async tryPlayBuffer() {
+                 if (!this.headerParsed || this.pcmData.length < this.minBufferSize) {
+                     return;
+                 }
+
+                 const bytesPerSample = this.numChannels * 2;
+                 const samplesToPlay = Math.floor(this.pcmData.length / bytesPerSample);
+                 const bytesToPlay = samplesToPlay * bytesPerSample;
+
+                 if (bytesToPlay === 0) return;
+
+                 const dataToPlay = this.pcmData.slice(0, bytesToPlay);
+                 this.pcmData = this.pcmData.slice(bytesToPlay);
+
+                 const audioBuffer = this.audioContext.createBuffer(
+                     this.numChannels,
+                     samplesToPlay,
+                     this.sampleRate
+                 );
+
+                 const int16Data = new Int16Array(dataToPlay.buffer, dataToPlay.byteOffset, samplesToPlay * this.numChannels);
+
+                 for (let channel = 0; channel < this.numChannels; channel++) {
+                     const channelData = audioBuffer.getChannelData(channel);
+                     for (let i = 0; i < samplesToPlay; i++) {
+                         channelData[i] = int16Data[i * this.numChannels + channel] / 32768;
+                     }
+                 }
+
+                 const source = this.audioContext.createBufferSource();
+                 source.buffer = audioBuffer;
+                 source.connect(this.audioContext.destination);
+
+                 const currentTime = this.audioContext.currentTime;
+                 const startTime = Math.max(currentTime, this.nextStartTime);
+
+                 source.start(startTime);
+
+                 // Track first audio playback
+                 if (!this.firstAudioPlayed && window.firstAudioCallback) {
+                     this.firstAudioPlayed = true;
+                     window.firstAudioCallback();
+                 }
+
+                 this.nextStartTime = startTime + audioBuffer.duration;
+                 this.isPlaying = true;
+
+                 if (this.pcmData.length >= this.minBufferSize) {
+                     setTimeout(() => this.tryPlayBuffer(), 10);
+                 }
+             }
+
+             addChunk(chunk) {
+                 if (!this.headerParsed) {
+                     const headerBytesNeeded = 44 - this.headerBytesReceived;
+                     const bytesToCopy = Math.min(headerBytesNeeded, chunk.length);
+
+                     this.headerBuffer.set(
+                         chunk.slice(0, bytesToCopy),
+                         this.headerBytesReceived
+                     );
+
+                     this.headerBytesReceived += bytesToCopy;
+
+                     if (this.headerBytesReceived >= 44) {
+                         this.parseWavHeader(this.headerBuffer);
+
+                         if (chunk.length > bytesToCopy) {
+                             this.appendPcmData(chunk.slice(bytesToCopy));
+                         }
+                     }
+                 } else {
+                     this.appendPcmData(chunk);
+                 }
+
+                 this.tryPlayBuffer();
+             }
+
+             stop() {
+                 this.audioContext.close();
+                 this.isPlaying = false;
+             }
+         }
+
+         // Application state
+         let streamingPlayer = null;
+         let currentAudioBlob = null;
+
+         // DOM elements
+         const textInput = document.getElementById('text-input');
+         const generateBtn = document.getElementById('generate-btn');
+         const generateText = document.getElementById('generate-text');
+         const status = document.getElementById('status');
+         const statusText = document.getElementById('status-text');
+         const statusSpinner = document.getElementById('status-spinner');
+         const audioSection = document.getElementById('audio-section');
+         const audio = document.getElementById('audio');
+         const downloadLink = document.getElementById('download-link');
+
+         // Event listeners
+         generateBtn.addEventListener('click', generateAudio);
+
+         function showStatus(message, isLoading = false) {
+             statusText.textContent = message;
+             statusSpinner.style.display = isLoading ? 'block' : 'none';
+             status.classList.remove('hidden');
+         }
+
+         function hideStatus() {
+             status.classList.add('hidden');
+         }
+
+         async function generateAudio() {
+             const text = textInput.value.trim();
+
+             if (!text) {
+                 showStatus('Please enter some text to generate speech.', false);
+                 setTimeout(hideStatus, 3000);
+                 return;
+             }
+
+             // Stop any currently playing audio
+             if (streamingPlayer) {
+                 streamingPlayer.stop();
+                 streamingPlayer = null;
+             }
+
+             // Track timing
+             const startTime = performance.now();
+             let firstAudioTime = null;
+
+             // Set callback for first audio
+             window.firstAudioCallback = () => {
+                 if (!firstAudioTime) {
+                     firstAudioTime = performance.now();
+                     const timeToFirstAudio = ((firstAudioTime - startTime) / 1000).toFixed(2);
+                     showStatus(`First audio in ${timeToFirstAudio}s...`, true);
+                 }
+             };
+
+             // Update UI
+             generateBtn.disabled = true;
+             generateText.textContent = 'Generating...';
+             showStatus('Generating speech...', true);
+             audioSection.classList.add('hidden');
+
+             try {
+                 const formData = new FormData();
+                 formData.append('text', text);
+
+                 // Add the voice URL if provided (only if no audio file is uploaded)
+                 const voiceUrl = document.getElementById('voice-url-input').value.trim();
+                 const voiceWavFile = document.getElementById('voice-wav-input').files[0];
+
+                 if (voiceWavFile) {
+                     // An uploaded audio file takes precedence over the voice URL
+                     formData.append('voice_wav', voiceWavFile);
+                 } else if (voiceUrl) {
+                     // Only use the voice URL if no audio file is uploaded
+                     formData.append('voice_url', voiceUrl);
+                 }
+
+                 const response = await fetch('/tts', {
+                     method: 'POST',
+                     body: formData
+                 });
+
+                 if (!response.ok) {
+                     throw new Error(`Server error: ${response.status}`);
+                 }
+
+                 // Clone the response for both streaming playback and blob collection
+                 const responseForPlayback = response.clone();
+                 const responseForHistory = response.clone();
+
+                 // Start streaming playback
+                 const reader = responseForPlayback.body.getReader();
+                 streamingPlayer = new StreamingWavPlayer();
+
+                 const processStream = async () => {
+                     try {
+                         while (true) {
+                             const { done, value } = await reader.read();
+                             if (done) break;
+
+                             if (value) {
+                                 streamingPlayer.addChunk(value);
+                             }
+                         }
+                     } catch (e) {
+                         console.error('Error processing stream:', e);
+                     }
+                 };
+
+                 // Start processing the stream
+                 processStream();
+
+                 // Also collect the full blob for the download link / audio element
+                 const blob = await responseForHistory.blob();
+                 const totalTime = ((performance.now() - startTime) / 1000).toFixed(2);
+                 currentAudioBlob = blob;
+
+                 // Update UI with audio
+                 const audioUrl = URL.createObjectURL(blob);
+                 audio.src = audioUrl;
+
+                 // Wait for the audio metadata to load (registered once so repeated generations don't stack listeners)
+                 audio.addEventListener('loadedmetadata', () => {
+                     const audioDuration = audio.duration;
+                     const speedRatio = (audioDuration / parseFloat(totalTime)).toFixed(1);
+
+                     const timeToFirst = firstAudioTime ? ((firstAudioTime - startTime) / 1000).toFixed(2) : 'N/A';
+
+                     showStatus(
+                         `✨ First audio: ${timeToFirst}s | Total: ${totalTime}s | ${speedRatio}x faster than real-time`,
+                         false
+                     );
+                 }, { once: true });
+                 downloadLink.href = audioUrl;
+                 downloadLink.download = 'tts-audio.wav';
+
+                 audioSection.classList.remove('hidden');
+
+             } catch (error) {
+                 console.error('Error generating audio:', error);
+                 showStatus(`Error: ${error.message}`, false);
+                 setTimeout(hideStatus, 3000);
+             } finally {
+                 // Re-enable the button
+                 generateBtn.disabled = false;
+                 generateText.textContent = 'Generate audio';
+             }
+         }
+
+         // Focus the text input when the page loads
+         window.addEventListener('load', () => {
+             textInput.focus();
+         });
+     </script>
+ </body>
+ </html>
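For reference, the `/tts` endpoint this page calls can be exercised without the browser. A minimal Python sketch: the route and the form field names ("text", "voice_url", "voice_wav") come from the page above, while the host and port are assumptions about a local deployment.

import requests

# Assumed local server address.
BASE_URL = "http://localhost:8000"

# (None, value) entries make requests send multipart/form-data,
# mirroring the browser's FormData submission.
fields = {
    "text": (None, "Hello world."),
    "voice_url": (None, "alba"),
}

with requests.post(f"{BASE_URL}/tts", files=fields, stream=True) as response:
    response.raise_for_status()
    with open("tts-audio.wav", "wb") as f:
        # The server streams WAV bytes as they are generated.
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)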
@@ -0,0 +1 @@
+ """Utilities."""
@@ -0,0 +1,122 @@
+ """Configuration models for loading YAML config files."""
+
+ from pathlib import Path
+
+ import yaml
+ from pydantic import BaseModel, ConfigDict
+
+
+ class StrictModel(BaseModel):
+     model_config = ConfigDict(extra="forbid")
+
+
+ # Flow configuration
+ class FlowConfig(StrictModel):
+     dim: int
+     depth: int
+
+
+ # Transformer configuration for FlowLM
+ class FlowLMTransformerConfig(StrictModel):
+     hidden_scale: int
+     max_period: int
+     d_model: int
+     num_heads: int
+     num_layers: int
+
+
+ class LookupTable(StrictModel):
+     dim: int
+     n_bins: int
+     tokenizer: str
+     tokenizer_path: str
+
+
+ # Root configuration
+ class FlowLMConfig(StrictModel):
+     """Root configuration model for YAML config files."""
+
+     dtype: str
+
+     # Nested configurations
+     flow: FlowConfig
+     transformer: FlowLMTransformerConfig
+
+     # conditioning
+     lookup_table: LookupTable
+     weights_path: str | None = None
+
+
+ # SEANet configuration
+ class SEANetConfig(StrictModel):
+     dimension: int
+     channels: int
+     n_filters: int
+     n_residual_layers: int
+     ratios: list[int]
+     kernel_size: int
+     residual_kernel_size: int
+     last_kernel_size: int
+     dilation_base: int
+     pad_mode: str
+     compress: int
+
+
+ # Transformer configuration for Mimi
+ class MimiTransformerConfig(StrictModel):
+     d_model: int
+     input_dimension: int
+     output_dimensions: tuple[int, ...]
+     num_heads: int
+     num_layers: int
+     layer_scale: float
+     context: int
+     max_period: float = 10000.0
+     dim_feedforward: int
+
+
+ # Quantizer configuration
+ class QuantizerConfig(StrictModel):
+     dimension: int
+     output_dimension: int
+
+
+ # Root configuration
+ class MimiConfig(StrictModel):
+     """Root configuration model for Mimi YAML config files."""
+
+     dtype: str
+
+     # Sample rate and channels
+     sample_rate: int
+     channels: int
+     frame_rate: float
+
+     # SEANet configurations
+     seanet: SEANetConfig
+
+     # Transformer
+     transformer: MimiTransformerConfig
+
+     # Quantizer
+     quantizer: QuantizerConfig
+     weights_path: str | None = None
+
+
+ class Config(StrictModel):
+     flow_lm: FlowLMConfig
+     mimi: MimiConfig
+     weights_path: str | None = None
+     weights_path_without_voice_cloning: str | None = None
+
+
+ def load_config(yaml_path: str | Path) -> Config:
+     yaml_path = Path(yaml_path)
+
+     if not yaml_path.exists():
+         raise FileNotFoundError(f"Config file not found: {yaml_path}")
+
+     with open(yaml_path, "r") as f:
+         config_dict = yaml.safe_load(f)
+
+     return Config(**config_dict)
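A minimal sketch of a config that validates against these models. Every key mirrors the field names of the classes above; the numeric values and paths are illustrative placeholders, not the shipped defaults, and the import path is an assumption.

import yaml

from pocket_tts.config import Config  # module path is an assumption

config_yaml = """
flow_lm:
  dtype: bfloat16
  flow: {dim: 512, depth: 6}
  transformer:
    hidden_scale: 4
    max_period: 10000
    d_model: 512
    num_heads: 8
    num_layers: 12
  lookup_table:
    dim: 512
    n_bins: 4096
    tokenizer: sentencepiece
    tokenizer_path: path/to/tokenizer.model
mimi:
  dtype: float32
  sample_rate: 24000
  channels: 1
  frame_rate: 12.5
  seanet:
    dimension: 512
    channels: 1
    n_filters: 64
    n_residual_layers: 1
    ratios: [8, 6, 5, 4]
    kernel_size: 7
    residual_kernel_size: 3
    last_kernel_size: 3
    dilation_base: 2
    pad_mode: constant
    compress: 2
  transformer:
    d_model: 512
    input_dimension: 512
    output_dimensions: [512]
    num_heads: 8
    num_layers: 8
    layer_scale: 0.01
    context: 250
    dim_feedforward: 2048
  quantizer: {dimension: 256, output_dimension: 512}
"""

config_dict = yaml.safe_load(config_yaml)
# StrictModel sets extra="forbid", so any misspelled or unknown key
# raises a pydantic ValidationError instead of being silently ignored.
config = Config(**config_dict)
print(config.mimi.sample_rate)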
@@ -0,0 +1,26 @@
+ import torch
+ from torch.utils._python_dispatch import TorchDispatchMode
+
+
+ def to_str(obj):
+     if isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
+         return f"T(s={list(obj.shape)})"
+     elif isinstance(obj, (list, tuple)):
+         return "[" + ", ".join(to_str(o) for o in obj) + "]"
+     elif isinstance(obj, dict):
+         return "{" + ", ".join(f"{to_str(k)}: {to_str(v)}" for k, v in obj.items()) + "}"
+     else:
+         return str(obj)
+
+
+ class LoggingMode(TorchDispatchMode):
+     """Useful to check implementation differences."""
+
+     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+         output = func(*args, **kwargs or {})
+         print(
+             f"Aten function called: {func}, args: "
+             f"{to_str(args)}, kwargs: {to_str(kwargs)} -> "
+             f"output: {to_str(output)}"
+         )
+         return output
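A quick usage sketch: entering the mode makes every ATen call print its operator name along with argument and output shapes, which is handy when diffing two implementations of the same model. The import path is an assumption.

import torch

from pocket_tts.logging_mode import LoggingMode  # module path is an assumption

with LoggingMode():
    x = torch.randn(2, 3)
    y = x @ torch.randn(3, 4)
# Prints lines such as:
# Aten function called: aten.mm.default, args: [T(s=[2, 3]), T(s=[3, 4])], kwargs: None -> output: T(s=[2, 4])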
@@ -0,0 +1,43 @@
+ import logging
+ from contextlib import contextmanager
+
+
+ class PocketTTSFilter(logging.Filter):
+     def filter(self, record):
+         return record.name.startswith("pocket_tts")
+
+
+ @contextmanager
+ def enable_logging(library_name, level):
+     # Get the specific logger and its parent
+     logger = logging.getLogger(library_name)
+     parent_logger = logging.getLogger("pocket_tts")
+
+     # Store the original configuration
+     old_level = logger.level
+     old_parent_level = parent_logger.level
+     old_handlers = parent_logger.handlers.copy()
+     old_propagate = parent_logger.propagate
+
+     # Configure the logging format for the pocket_tts logger
+     parent_logger.setLevel(level)
+
+     # Clear existing handlers and add our custom formatter with a filter
+     parent_logger.handlers.clear()
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter("%(levelname)s: %(message)s")
+     handler.setFormatter(formatter)
+     handler.addFilter(PocketTTSFilter())
+     parent_logger.addHandler(handler)
+     parent_logger.propagate = False
+
+     try:
+         yield logger
+     finally:
+         # Restore the original configuration
+         logger.setLevel(old_level)
+         parent_logger.setLevel(old_parent_level)
+         parent_logger.propagate = old_propagate
+         parent_logger.handlers.clear()
+         for h in old_handlers:
+             parent_logger.addHandler(h)
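Usage sketch: wrap a call in the context manager to surface the library's log output at the chosen level; the previous levels, handlers, and propagation flag are restored on exit. The "pocket_tts.models" submodule name is hypothetical.

import logging

with enable_logging("pocket_tts.models", logging.DEBUG) as logger:
    logger.debug("Visible while the context is active.")
# Outside the block, the original logging configuration is back in place.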
@@ -0,0 +1,103 @@
+ import hashlib
+ import logging
+ import time
+ from pathlib import Path
+
+ import requests
+ import safetensors.torch
+ import torch
+ from huggingface_hub import hf_hub_download
+ from torch import nn
+
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
+
+ _voices_names = ["alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma"]
+ PREDEFINED_VOICES = {
+     # don't forget to change this
+     x: f"hf://kyutai/pocket-tts-without-voice-cloning/embeddings/{x}.safetensors@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3"
+     for x in _voices_names
+ }
+
+
+ def make_cache_directory() -> Path:
+     cache_dir = Path.home() / ".cache" / "pocket_tts"
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     return cache_dir
+
+
+ def print_nb_parameters(model: nn.Module, model_name: str):
+     logger = logging.getLogger(__name__)
+     state_dict = model.state_dict()
+     total = 0
+     for key, value in state_dict.items():
+         logger.info("%s: %s", key, f"{value.numel():,}")
+         total += value.numel()
+     logger.info("Total number of parameters in %s: %s", model_name, f"{total:,}")
+
+
+ def size_of_dict(state_dict: dict) -> int:
+     total_size = 0
+     for value in state_dict.values():
+         if isinstance(value, torch.Tensor):
+             total_size += value.numel() * value.element_size()
+         elif isinstance(value, dict):
+             total_size += size_of_dict(value)
+     return total_size
+
+
+ class display_execution_time:
+     def __init__(self, task_name: str, print_output: bool = True):
+         self.task_name = task_name
+         self.print_output = print_output
+         self.start_time = None
+         self.elapsed_time_ms = None
+         self.logger = logging.getLogger(__name__)
+
+     def __enter__(self):
+         self.start_time = time.monotonic()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         end_time = time.monotonic()
+         self.elapsed_time_ms = int((end_time - self.start_time) * 1000)
+         if self.print_output:
+             self.logger.info("%s took %d ms", self.task_name, self.elapsed_time_ms)
+         return False  # Don't suppress exceptions
+
+
+ def download_if_necessary(file_path: str) -> Path:
+     if file_path.startswith("http://") or file_path.startswith("https://"):
+         cache_dir = make_cache_directory()
+         cached_file = cache_dir / (
+             hashlib.sha256(file_path.encode()).hexdigest() + "." + file_path.split(".")[-1]
+         )
+         if not cached_file.exists():
+             response = requests.get(file_path)
+             response.raise_for_status()
+             with open(cached_file, "wb") as f:
+                 f.write(response.content)
+         return cached_file
+     elif file_path.startswith("hf://"):
+         file_path = file_path.removeprefix("hf://")
+         splitted = file_path.split("/")
+         repo_id = "/".join(splitted[:2])
+         filename = "/".join(splitted[2:])
+         if "@" in filename:
+             filename, revision = filename.split("@")
+         else:
+             revision = None
+         cached_file = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
+         return Path(cached_file)
+     else:
+         return Path(file_path)
+
+
+ def load_predefined_voice(voice_name: str) -> torch.Tensor:
+     if voice_name not in PREDEFINED_VOICES:
+         raise ValueError(
+             f"Predefined voice '{voice_name}' not found"
+             f", available voices are {list(PREDEFINED_VOICES)}."
+         )
+     voice_file = download_if_necessary(PREDEFINED_VOICES[voice_name])
+     # There is only one tensor in the file.
+     return safetensors.torch.load_file(voice_file)["audio_prompt"]