pocket-tts 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0
+++ pocket_tts/static/index.html
@@ -0,0 +1,374 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Pocket TTS - Streaming</title>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <style>
+        body {
+            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+        }
+        .spinner {
+            animation: spin 1s linear infinite;
+        }
+        @keyframes spin {
+            from { transform: rotate(0deg); }
+            to { transform: rotate(360deg); }
+        }
+    </style>
+</head>
+<body class="bg-gray-50 min-h-screen">
+    <div class="max-w-xl mx-auto p-4 space-y-4">
+        <h1 class="text-xl font-bold text-center">Pocket TTS</h1>
+
+        <div class="bg-white p-4 rounded-lg shadow">
+            <textarea
+                id="text-input"
+                class="w-full border rounded p-2"
+                placeholder="Enter your text here..."
+                rows="4"
+            >Hello world. I am Kyutai's Pocket TTS. I'm fast enough to run on small CPUs. I hope you'll like me.</textarea>
+        </div>
+
+        <div class="bg-white p-4 rounded-lg shadow">
+            <label for="voice-url-input" class="block text-sm font-medium text-gray-700 mb-2">
+                Optional voice URL (leave empty to use default voice):
+            </label>
+            <input
+                type="text"
+                id="voice-url-input"
+                class="w-full border rounded p-2"
+                placeholder="hf://kyutai/tts-voices/alba-mackenna/casual.wav"
+                value="alba"
+            />
+            <p class="text-xs text-gray-500 mt-1">
+                Supports: http://, https://, or hf:// URLs.<br>
+                You can also use predefined voices:<br>
+                "alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma".
+            </p>
+        </div>
+
+        <div class="bg-white p-4 rounded-lg shadow">
+            <label for="voice-wav-input" class="block text-sm font-medium text-gray-700 mb-2">
+                Or upload an audio file for voice cloning:
+            </label>
+            <input
+                type="file"
+                id="voice-wav-input"
+                class="w-full border rounded p-2"
+                accept=".wav,.mp3,.flac,.ogg,.m4a,audio/*"
+            />
+            <p class="text-xs text-gray-500 mt-1">
+                Upload an audio file (WAV, MP3, FLAC, etc.) to use as voice reference. Takes precedence over voice URL.
+            </p>
+        </div>
+
+
+        <button
+            id="generate-btn"
+            class="w-full px-4 py-2 rounded text-white bg-blue-600 hover:bg-blue-700 disabled:bg-gray-400 disabled:cursor-not-allowed"
+        >
+            <span id="generate-text">Generate audio</span>
+        </button>
+
+        <div id="status" class="hidden">
+            <div class="flex items-center space-x-2">
+                <div id="status-spinner" class="spinner w-4 h-4 border-2 border-blue-600 border-t-transparent rounded-full"></div>
+                <span id="status-text"></span>
+            </div>
+        </div>
+
+        <div id="audio-section" class="hidden">
+            <div class="bg-white p-4 rounded-lg shadow">
+                <h3 class="text-lg font-semibold mb-2">Generated Audio</h3>
+                <audio id="audio" controls class="w-full"></audio>
+                <div class="mt-2">
+                    <a id="download-link" class="text-blue-600 text-sm hover:underline">Download</a>
+                </div>
+            </div>
+        </div>
+    </div>
+
+    <script>
+        class StreamingWavPlayer {
+            constructor() {
+                this.audioContext = new (window.AudioContext || window.webkitAudioContext)();
+                this.sampleRate = 0;
+                this.numChannels = 0;
+                this.headerParsed = false;
+                this.headerBuffer = new Uint8Array(44);
+                this.headerBytesReceived = 0;
+                this.nextStartTime = 0;
+                this.isPlaying = false;
+                this.minBufferSize = 16384;
+                this.pcmData = new Uint8Array(0);
+            }
+
+            parseWavHeader(header) {
+                const view = new DataView(header.buffer);
+
+                const riff = String.fromCharCode.apply(null, Array.from(header.slice(0, 4)));
+                const wave = String.fromCharCode.apply(null, Array.from(header.slice(8, 12)));
+
+                if (riff !== 'RIFF' || wave !== 'WAVE') {
+                    throw new Error('Invalid WAV file');
+                }
+
+                this.numChannels = view.getUint16(22, true);
+                this.sampleRate = view.getUint32(24, true);
+                const bitsPerSample = view.getUint16(34, true);
+
+                console.log(`WAV Format: ${this.sampleRate}Hz, ${this.numChannels} channels, ${bitsPerSample} bits`);
+
+                this.headerParsed = true;
+            }
+
+            appendPcmData(newData) {
+                const newBuffer = new Uint8Array(this.pcmData.length + newData.length);
+                newBuffer.set(this.pcmData);
+                newBuffer.set(newData, this.pcmData.length);
+                this.pcmData = newBuffer;
+            }
+
+            async tryPlayBuffer() {
+                if (!this.headerParsed || this.pcmData.length < this.minBufferSize) {
+                    return;
+                }
+
+                const bytesPerSample = this.numChannels * 2;
+                const samplesToPlay = Math.floor(this.pcmData.length / bytesPerSample);
+                const bytesToPlay = samplesToPlay * bytesPerSample;
+
+                if (bytesToPlay === 0) return;
+
+                const dataToPlay = this.pcmData.slice(0, bytesToPlay);
+                this.pcmData = this.pcmData.slice(bytesToPlay);
+
+                const audioBuffer = this.audioContext.createBuffer(
+                    this.numChannels,
+                    samplesToPlay,
+                    this.sampleRate
+                );
+
+                const int16Data = new Int16Array(dataToPlay.buffer, dataToPlay.byteOffset, samplesToPlay * this.numChannels);
+
+                for (let channel = 0; channel < this.numChannels; channel++) {
+                    const channelData = audioBuffer.getChannelData(channel);
+                    for (let i = 0; i < samplesToPlay; i++) {
+                        channelData[i] = int16Data[i * this.numChannels + channel] / 32768;
+                    }
+                }
+
+                const source = this.audioContext.createBufferSource();
+                source.buffer = audioBuffer;
+                source.connect(this.audioContext.destination);
+
+                const currentTime = this.audioContext.currentTime;
+                const startTime = Math.max(currentTime, this.nextStartTime);
+
+                source.start(startTime);
+
+                // Track first audio playback
+                if (!this.firstAudioPlayed && window.firstAudioCallback) {
+                    this.firstAudioPlayed = true;
+                    window.firstAudioCallback();
+                }
+
+                this.nextStartTime = startTime + audioBuffer.duration;
+                this.isPlaying = true;
+
+                if (this.pcmData.length >= this.minBufferSize) {
+                    setTimeout(() => this.tryPlayBuffer(), 10);
+                }
+            }
+
+            addChunk(chunk) {
+                if (!this.headerParsed) {
+                    const headerBytesNeeded = 44 - this.headerBytesReceived;
+                    const bytesToCopy = Math.min(headerBytesNeeded, chunk.length);
+
+                    this.headerBuffer.set(
+                        chunk.slice(0, bytesToCopy),
+                        this.headerBytesReceived
+                    );
+
+                    this.headerBytesReceived += bytesToCopy;
+
+                    if (this.headerBytesReceived >= 44) {
+                        this.parseWavHeader(this.headerBuffer);
+
+                        if (chunk.length > bytesToCopy) {
+                            this.appendPcmData(chunk.slice(bytesToCopy));
+                        }
+                    }
+                } else {
+                    this.appendPcmData(chunk);
+                }
+
+                this.tryPlayBuffer();
+            }
+
+            stop() {
+                this.audioContext.close();
+                this.isPlaying = false;
+            }
+        }
+
+        // Application state
+        let streamingPlayer = null;
+        let currentAudioBlob = null;
+
+        // DOM elements
+        const textInput = document.getElementById('text-input');
+        const generateBtn = document.getElementById('generate-btn');
+        const generateText = document.getElementById('generate-text');
+        const status = document.getElementById('status');
+        const statusText = document.getElementById('status-text');
+        const statusSpinner = document.getElementById('status-spinner');
+        const audioSection = document.getElementById('audio-section');
+        const audio = document.getElementById('audio');
+        const downloadLink = document.getElementById('download-link');
+
+        // Event listeners
+        generateBtn.addEventListener('click', generateAudio);
+
+        function showStatus(message, isLoading = false) {
+            statusText.textContent = message;
+            statusSpinner.style.display = isLoading ? 'block' : 'none';
+            status.classList.remove('hidden');
+        }
+
+        function hideStatus() {
+            status.classList.add('hidden');
+        }
+
+        async function generateAudio() {
+            const text = textInput.value.trim();
+
+            if (!text) {
+                showStatus('Please enter some text to generate speech.', false);
+                setTimeout(hideStatus, 3000);
+                return;
+            }
+
+            // Stop any currently playing audio
+            if (streamingPlayer) {
+                streamingPlayer.stop();
+                streamingPlayer = null;
+            }
+
+            // Track timing
+            const startTime = performance.now();
+            let firstAudioTime = null;
+
+            // Set callback for first audio
+            window.firstAudioCallback = () => {
+                if (!firstAudioTime) {
+                    firstAudioTime = performance.now();
+                    const timeToFirstAudio = ((firstAudioTime - startTime) / 1000).toFixed(2);
+                    showStatus(`First audio in ${timeToFirstAudio}s...`, true);
+                }
+            };
+
+            // Update UI
+            generateBtn.disabled = true;
+            generateText.textContent = 'Generating...';
+            showStatus('Generating speech...', true);
+            audioSection.classList.add('hidden');
+
+            try {
+                const formData = new FormData();
+                formData.append('text', text);
+
+                // Add voice URL if provided (only if no WAV file is uploaded)
+                const voiceUrl = document.getElementById('voice-url-input').value.trim();
+                const voiceWavFile = document.getElementById('voice-wav-input').files[0];
+
+                if (voiceWavFile) {
+                    // If WAV file is uploaded, only use the WAV file (ignore voice URL)
+                    formData.append('voice_wav', voiceWavFile);
+                } else if (voiceUrl) {
+                    // Only use voice URL if no WAV file is uploaded
+                    formData.append('voice_url', voiceUrl);
+                }
+
+                const response = await fetch('/tts', {
+                    method: 'POST',
+                    body: formData
+                });
+
+                if (!response.ok) {
+                    throw new Error(`Server error: ${response.status}`);
+                }
+
+                // Clone the response for both streaming and blob collection
+                const responseForPlayback = response.clone();
+                const responseForHistory = response.clone();
+
+                // Start streaming playback
+                const reader = responseForPlayback.body.getReader();
+                streamingPlayer = new StreamingWavPlayer();
+
+                const processStream = async () => {
+                    try {
+                        while (true) {
+                            const { done, value } = await reader.read();
+                            if (done) break;
+
+                            if (value) {
+                                streamingPlayer.addChunk(value);
+                            }
+                        }
+                    } catch (e) {
+                        console.error('Error processing stream:', e);
+                    }
+                };
+
+                // Start processing the stream
+                processStream();
+
+                // Also collect the full blob for download/audio element
+                const blob = await responseForHistory.blob();
+                const totalTime = ((performance.now() - startTime) / 1000).toFixed(2);
+                currentAudioBlob = blob;
+
+                // Update UI with audio
+                const audioUrl = URL.createObjectURL(blob);
+                audio.src = audioUrl;
+
+                // Wait for audio metadata to load to get duration
+                audio.addEventListener('loadedmetadata', () => {
+                    const audioDuration = audio.duration;
+                    const speedRatio = (audioDuration / parseFloat(totalTime)).toFixed(1);
+
+                    const timeToFirst = firstAudioTime ? ((firstAudioTime - startTime) / 1000).toFixed(2) : 'N/A';
+
+                    showStatus(
+                        `✨ First audio: ${timeToFirst}s | Total: ${totalTime}s | ${speedRatio}x faster than real-time`,
+                        false
+                    );
+                });
+                downloadLink.href = audioUrl;
+                downloadLink.download = `tts-audio.wav`;
+
+                audioSection.classList.remove('hidden');
+
+            } catch (error) {
+                console.error('Error generating audio:', error);
+                showStatus(`Error: ${error.message}`, false);
+                setTimeout(hideStatus, 3000);
+            } finally {
+                // Re-enable button
+                generateBtn.disabled = false;
+                generateText.textContent = 'Generate Audio';
+            }
+        }
+
+        // Focus on text input when page loads
+        window.addEventListener('load', () => {
+            textInput.focus();
+        });
+    </script>
+</body>
+</html>
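The markup above drives a single `/tts` endpoint: a multipart POST carrying `text` plus an optional `voice_url` or `voice_wav` field, answered with a WAV body that `StreamingWavPlayer` schedules chunk-by-chunk once roughly 16 KB of PCM has arrived. A minimal Python client sketch of that same contract, inferred from the JavaScript above (the base URL is an assumption, not something this diff specifies):

```python
# Hypothetical client for the /tts endpoint used by the page above.
# Field names (text, voice_url) and the streamed WAV response are taken
# from the JavaScript; the host and port are assumptions for illustration.
import requests


def stream_tts(base_url: str, text: str, voice_url: str | None = None) -> bytes:
    data = {"text": text}
    if voice_url:
        data["voice_url"] = voice_url  # e.g. "alba" or an hf:// voice URL
    chunks: list[bytes] = []
    with requests.post(f"{base_url}/tts", data=data, stream=True) as response:
        response.raise_for_status()
        for chunk in response.iter_content(chunk_size=8192):
            chunks.append(chunk)  # the first 44 bytes form the WAV header
    return b"".join(chunks)


# wav_bytes = stream_tts("http://localhost:8000", "Hello world.")
```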
+++ pocket_tts/utils/__init__.py
@@ -0,0 +1 @@
+"""Utilities."""
+++ pocket_tts/utils/config.py
@@ -0,0 +1,122 @@
+"""Configuration models for loading YAML config files."""
+
+from pathlib import Path
+
+import yaml
+from pydantic import BaseModel, ConfigDict
+
+
+class StrictModel(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+
+# Flow configuration
+class FlowConfig(StrictModel):
+    dim: int
+    depth: int
+
+
+# Transformer configuration for FlowLM
+class FlowLMTransformerConfig(StrictModel):
+    hidden_scale: int
+    max_period: int
+    d_model: int
+    num_heads: int
+    num_layers: int
+
+
+class LookupTable(StrictModel):
+    dim: int
+    n_bins: int
+    tokenizer: str
+    tokenizer_path: str
+
+
+# Root configuration
+class FlowLMConfig(StrictModel):
+    """Root configuration model for YAML config files."""
+
+    dtype: str
+
+    # Nested configurations
+    flow: FlowConfig
+    transformer: FlowLMTransformerConfig
+
+    # conditioning
+    lookup_table: LookupTable
+    weights_path: str | None = None
+
+
+# SEANet configuration
+class SEANetConfig(StrictModel):
+    dimension: int
+    channels: int
+    n_filters: int
+    n_residual_layers: int
+    ratios: list[int]
+    kernel_size: int
+    residual_kernel_size: int
+    last_kernel_size: int
+    dilation_base: int
+    pad_mode: str
+    compress: int
+
+
+# Transformer configuration for Mimi
+class MimiTransformerConfig(StrictModel):
+    d_model: int
+    input_dimension: int
+    output_dimensions: tuple[int, ...]
+    num_heads: int
+    num_layers: int
+    layer_scale: float
+    context: int
+    max_period: float = 10000.0
+    dim_feedforward: int
+
+
+# Quantizer configuration
+class QuantizerConfig(StrictModel):
+    dimension: int
+    output_dimension: int
+
+
+# Root configuration
+class MimiConfig(StrictModel):
+    """Root configuration model for Mimi YAML config files."""
+
+    dtype: str
+
+    # Sample rate and channels
+    sample_rate: int
+    channels: int
+    frame_rate: float
+
+    # SEANet configurations
+    seanet: SEANetConfig
+
+    # Transformer
+    transformer: MimiTransformerConfig
+
+    # Quantizer
+    quantizer: QuantizerConfig
+    weights_path: str | None = None
+
+
+class Config(StrictModel):
+    flow_lm: FlowLMConfig
+    mimi: MimiConfig
+    weights_path: str | None = None
+    weights_path_without_voice_cloning: str | None = None
+
+
+def load_config(yaml_path: str | Path) -> Config:
+    yaml_path = Path(yaml_path)
+
+    if not yaml_path.exists():
+        raise FileNotFoundError(f"Config file not found: {yaml_path}")
+
+    with open(yaml_path, "r") as f:
+        config_dict = yaml.safe_load(f)
+
+    return Config(**config_dict)
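A sketch of how these models might be exercised, assuming the bundled `pocket_tts/config/b6369a24.yaml` (listed at the top of this diff) provides all required fields:

```python
# Hypothetical usage of load_config; attribute access mirrors the models above.
from pocket_tts.utils.config import load_config

config = load_config("pocket_tts/config/b6369a24.yaml")
print(config.mimi.sample_rate, config.mimi.frame_rate)
print(config.flow_lm.transformer.num_layers)
# Because StrictModel sets extra="forbid", any unknown YAML key raises a
# pydantic ValidationError instead of being silently ignored.
```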
+++ pocket_tts/utils/debugging.py
@@ -0,0 +1,26 @@
+import torch
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+def to_str(obj):
+    if isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
+        return f"T(s={list(obj.shape)})"
+    elif isinstance(obj, (list, tuple)):
+        return "[" + ", ".join(to_str(o) for o in obj) + "]"
+    elif isinstance(obj, dict):
+        return "{" + ", ".join(f"{to_str(k)}: {to_str(v)}" for k, v in obj.items()) + "}"
+    else:
+        return str(obj)
+
+
+class LoggingMode(TorchDispatchMode):
+    """Useful to check implementation differences."""
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        output = func(*args, **kwargs or {})
+        print(
+            f"Aten function called: {func}, args: "
+            f"{to_str(args)}, kwargs: {to_str(kwargs)} -> "
+            f"output: {to_str(output)}"
+        )
+        return output
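`TorchDispatchMode` subclasses act as context managers, so `LoggingMode` can wrap any tensor code to print every aten-level call; a minimal sketch:

```python
# Minimal sketch: every aten op dispatched inside the block is printed with
# shape summaries, which helps diff two implementations operator by operator.
import torch

from pocket_tts.utils.debugging import LoggingMode

with LoggingMode():
    x = torch.randn(2, 3)
    y = (x @ x.t()).relu()
# Each printed line follows the format string above, roughly:
# Aten function called: aten.relu.default, args: [T(s=[2, 2])], kwargs: None -> output: T(s=[2, 2])
```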
+++ pocket_tts/utils/logging_utils.py
@@ -0,0 +1,41 @@
+import logging
+from contextlib import contextmanager
+
+
+class PocketTTSFilter(logging.Filter):
+    def filter(self, record):
+        return record.name.startswith("pocket_tts")
+
+
+@contextmanager
+def enable_logging(library_name, level):
+    # Get the specific logger and its parent
+    logger = logging.getLogger(library_name)
+    parent_logger = logging.getLogger("pocket_tts")
+
+    # Store original configuration
+    old_level = logger.level
+    old_parent_level = parent_logger.level
+    old_handlers = parent_logger.handlers.copy()
+
+    # Configure logging format for pocket_tts logger
+    parent_logger.setLevel(level)
+
+    # Clear existing handlers and add our custom formatter with filter
+    parent_logger.handlers.clear()
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter("%(levelname)s: %(message)s")
+    handler.setFormatter(formatter)
+    handler.addFilter(PocketTTSFilter())
+    parent_logger.addHandler(handler)
+    parent_logger.propagate = False
+
+    try:
+        yield logger
+    finally:
+        # Restore original configuration
+        logger.setLevel(old_level)
+        parent_logger.setLevel(old_parent_level)
+        parent_logger.handlers.clear()
+        for h in old_handlers:
+            parent_logger.addHandler(h)
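A sketch of wrapping a call site so that `pocket_tts` records print at the chosen level, with the logger's previous state restored on exit:

```python
# Hedged example: enable_logging installs a StreamHandler with
# PocketTTSFilter on the "pocket_tts" parent logger for the duration
# of the block, then restores the original level and handlers.
import logging

from pocket_tts.utils.logging_utils import enable_logging

with enable_logging("pocket_tts.models", logging.INFO):
    logging.getLogger("pocket_tts.models").info("printed as 'INFO: ...'")
# Outside the block, the previous logging configuration applies again.
```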
+++ pocket_tts/utils/utils.py
@@ -0,0 +1,103 @@
+import hashlib
+import logging
+import time
+from pathlib import Path
+
+import requests
+import safetensors.torch
+import torch
+from huggingface_hub import hf_hub_download
+from torch import nn
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+
+_voices_names = ["alba", "marius", "javert", "jean", "fantine", "cosette", "eponine", "azelma"]
+PREDEFINED_VOICES = {
+    # don't forget to change this
+    x: f"hf://kyutai/pocket-tts-without-voice-cloning/embeddings/{x}.safetensors@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3"
+    for x in _voices_names
+}
+
+
+def make_cache_directory() -> Path:
+    cache_dir = Path.home() / ".cache" / "pocket_tts"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    return cache_dir
+
+
+def print_nb_parameters(model: nn.Module, model_name: str):
+    logger = logging.getLogger(__name__)
+    state_dict = model.state_dict()
+    total = 0
+    for key, value in state_dict.items():
+        # Format with thousands separators; %-style logging has no "," flag.
+        logger.info("%s: %s", key, f"{value.numel():,}")
+        total += value.numel()
+    logger.info("Total number of parameters in %s: %s", model_name, f"{total:,}")
+
+
+def size_of_dict(state_dict: dict) -> int:
+    total_size = 0
+    for value in state_dict.values():
+        if isinstance(value, torch.Tensor):
+            total_size += value.numel() * value.element_size()
+        elif isinstance(value, dict):
+            total_size += size_of_dict(value)
+    return total_size
+
+
+class display_execution_time:
+    def __init__(self, task_name: str, print_output: bool = True):
+        self.task_name = task_name
+        self.print_output = print_output
+        self.start_time = None
+        self.elapsed_time_ms = None
+        self.logger = logging.getLogger(__name__)
+
+    def __enter__(self):
+        self.start_time = time.monotonic()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end_time = time.monotonic()
+        self.elapsed_time_ms = int((end_time - self.start_time) * 1000)
+        if self.print_output:
+            self.logger.info("%s took %d ms", self.task_name, self.elapsed_time_ms)
+        return False  # Don't suppress exceptions
+
+
+def download_if_necessary(file_path: str) -> Path:
+    if file_path.startswith("http://") or file_path.startswith("https://"):
+        cache_dir = make_cache_directory()
+        cached_file = cache_dir / (
+            hashlib.sha256(file_path.encode()).hexdigest() + "." + file_path.split(".")[-1]
+        )
+        if not cached_file.exists():
+            response = requests.get(file_path)
+            response.raise_for_status()
+            with open(cached_file, "wb") as f:
+                f.write(response.content)
+        return cached_file
+    elif file_path.startswith("hf://"):
+        file_path = file_path.removeprefix("hf://")
+        splitted = file_path.split("/")
+        repo_id = "/".join(splitted[:2])
+        filename = "/".join(splitted[2:])
+        if "@" in filename:
+            filename, revision = filename.split("@")
+        else:
+            revision = None
+        cached_file = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
+        return Path(cached_file)
+    else:
+        return Path(file_path)
+
+
+def load_predefined_voice(voice_name: str) -> torch.Tensor:
+    if voice_name not in PREDEFINED_VOICES:
+        raise ValueError(
+            f"Predefined voice '{voice_name}' not found"
+            f", available voices are {list(PREDEFINED_VOICES)}."
+        )
+    voice_file = download_if_necessary(PREDEFINED_VOICES[voice_name])
+    # There is only one tensor in the file.
+    return safetensors.torch.load_file(voice_file)["audio_prompt"]
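For illustration, the three path forms `download_if_necessary` accepts, plus `load_predefined_voice` on one of the names listed above; the https URL below is a placeholder and the hf:// path is the example shown by the web UI:

```python
from pocket_tts.utils.utils import download_if_necessary, load_predefined_voice

p1 = download_if_necessary("/tmp/voice.wav")                 # local path, returned unchanged
p2 = download_if_necessary("https://example.com/voice.wav")  # placeholder URL; cached under ~/.cache/pocket_tts by URL hash
p3 = download_if_necessary("hf://kyutai/tts-voices/alba-mackenna/casual.wav")  # resolved via hf_hub_download
prompt = load_predefined_voice("alba")  # "audio_prompt" tensor from the revision-pinned safetensors file
```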