pocket_tts-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0

pocket_tts/models/tts_model.py
@@ -0,0 +1,782 @@
import copy
import logging
import os
import queue
import statistics
import threading
import time
from functools import lru_cache
from pathlib import Path

import safetensors
import torch
from torch import nn
from torch.nn import functional as F
from typing_extensions import Self

from pocket_tts.conditioners.base import TokenizedText
from pocket_tts.data.audio import audio_read
from pocket_tts.data.audio_utils import convert_audio
from pocket_tts.default_parameters import (
    DEFAULT_EOS_THRESHOLD,
    DEFAULT_LSD_DECODE_STEPS,
    DEFAULT_NOISE_CLAMP,
    DEFAULT_TEMPERATURE,
    DEFAULT_VARIANT,
)
from pocket_tts.models.flow_lm import FlowLMModel
from pocket_tts.models.mimi import MimiModel
from pocket_tts.modules import mimi_transformer
from pocket_tts.modules.dummy_quantizer import DummyQuantizer
from pocket_tts.modules.seanet import SEANetDecoder, SEANetEncoder
from pocket_tts.modules.stateful_module import increment_steps, init_states
from pocket_tts.utils.config import Config, load_config
from pocket_tts.utils.utils import (
    PREDEFINED_VOICES,
    display_execution_time,
    download_if_necessary,
    load_predefined_voice,
    size_of_dict,
)
from pocket_tts.utils.weights_loading import get_flow_lm_state_dict, get_mimi_state_dict

torch.set_num_threads(1)
logger = logging.getLogger(__name__)


class TTSModel(nn.Module):
    def __init__(
        self,
        flow_lm: FlowLMModel,
        temp: float,
        lsd_decode_steps: int,
        noise_clamp: float | None,
        eos_threshold,
        config: Config,
    ):
        super().__init__()
        self.flow_lm = flow_lm
        self.temp = temp
        self.lsd_decode_steps = lsd_decode_steps
        self.noise_clamp = noise_clamp
        self.eos_threshold = eos_threshold
        self.config = config
        self.has_voice_cloning = True

    @property
    def device(self) -> str:
        return next(self.parameters()).device.type

    @property
    def sample_rate(self) -> int:
        return self.config.mimi.sample_rate

    @classmethod
    def _from_pydantic_config(
        cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
    ) -> Self:
        flow_lm = FlowLMModel.from_pydantic_config(
            config.flow_lm, latent_dim=config.mimi.quantizer.dimension
        )
        tts_model = cls(flow_lm, temp, lsd_decode_steps, noise_clamp, eos_threshold, config)
        return tts_model

    @classmethod
    def _from_pydantic_config_with_weights(
        cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
    ) -> Self:
        tts_model = cls._from_pydantic_config(
            config, temp, lsd_decode_steps, noise_clamp, eos_threshold
        )
        tts_model.flow_lm.speaker_proj_weight = torch.nn.Parameter(
            torch.zeros((1024, 512), dtype=torch.float32)
        )
        if config.flow_lm.weights_path is not None:
            if config.mimi.weights_path is None:
                raise ValueError(
                    "If you specify flow_lm.weights_path you should specify mimi.weights_path"
                )
            logger.info(f"Loading FlowLM weights from {config.flow_lm.weights_path}")
            state_dict_flowlm = get_flow_lm_state_dict(
                download_if_necessary(config.flow_lm.weights_path)
            )
            tts_model.flow_lm.load_state_dict(state_dict_flowlm, strict=True)

        # safetensors.torch.save_file(tts_model.state_dict(), "7442637a.safetensors")
        # Create mimi config directly from the provided config using model_dump
        mimi_config = config.mimi.model_dump()

        # Build mimi model from config
        encoder = SEANetEncoder(**mimi_config["seanet"])
        decoder = SEANetDecoder(**mimi_config["seanet"])

        encoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
        decoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
        quantizer = DummyQuantizer(**mimi_config["quantizer"])

        tts_model.mimi = MimiModel(
            encoder,
            decoder,
            quantizer,
            channels=mimi_config["channels"],
            sample_rate=mimi_config["sample_rate"],
            frame_rate=mimi_config["frame_rate"],
            encoder_frame_rate=mimi_config["sample_rate"] / encoder.hop_length,
            encoder_transformer=encoder_transformer,
            decoder_transformer=decoder_transformer,
        ).to(device="cpu")

        # Load mimi weights from the config safetensors file with complete mapping for strict loading

        if config.mimi.weights_path is not None:
            if config.flow_lm.weights_path is None:
                raise ValueError(
                    "If you specify mimi.weights_path you should specify flow_lm.weights_path"
                )
            logger.info(f"Loading Mimi weights from {config.mimi.weights_path}")
            mimi_state = get_mimi_state_dict(download_if_necessary(config.mimi.weights_path))
            tts_model.mimi.load_state_dict(mimi_state, strict=True)

        tts_model.mimi.eval()
        # tts_model.to(dtype=torch.float32)

        # uncomment to save the weights
        # tts_model = tts_model.to(dtype=torch.bfloat16)
        # safetensors.torch.save_file(tts_model.state_dict(), "tts_b6369a24.safetensors")
        if config.weights_path is not None:
            logger.info(f"Loading TTSModel weights from {config.weights_path}")
            try:
                weights_file = download_if_necessary(config.weights_path)
            except Exception:
                tts_model.has_voice_cloning = False
                weights_file = download_if_necessary(config.weights_path_without_voice_cloning)

            state_dict = safetensors.torch.load_file(weights_file)
            tts_model.load_state_dict(state_dict, strict=True)

        if config.flow_lm.weights_path is None and config.weights_path is None:
            logger.warning(
                "No weights_path specified for FlowLM or TTSModel, model is uninitialized!"
            )
        size_in_mb = size_of_dict(tts_model.state_dict()) // 1e6
        logging.info(f"TTS Model loaded successfully. Its size is {size_in_mb} MB")

        return tts_model

    def load_model(
        variant: str = DEFAULT_VARIANT,
        temp: float | int = DEFAULT_TEMPERATURE,
        lsd_decode_steps: int = DEFAULT_LSD_DECODE_STEPS,
        noise_clamp: float | int | None = DEFAULT_NOISE_CLAMP,
        eos_threshold: float = DEFAULT_EOS_THRESHOLD,
    ) -> Self:
        """Load a pre-trained TTS model with specified configuration.

        This class method loads a complete TTS model including the flow language model
        and Mimi compression model from pre-trained weights. The model is initialized
        with the specified generation parameters and ready for inference.

        Args:
            variant: Model variant identifier corresponding to a config file name
                (e.g., '610b0b2c'). Must match a YAML file in the config directory.
            temp: Sampling temperature for generation. Higher values produce more
                diverse but potentially lower quality output.
            lsd_decode_steps: Number of steps for Lagrangian Self Distillation
                decoding. More steps can improve quality but increase computation.
            noise_clamp: Maximum value for noise sampling. If None, no clamping
                is applied. Helps prevent extreme values in generation.
            eos_threshold: Threshold for end-of-sequence detection. Higher values
                make the model more likely to continue generating.

        Returns:
            TTSModel: Fully initialized model with loaded weights on cpu, ready for
                text-to-speech generation.

        Raises:
            FileNotFoundError: If the specified config file or model weights
                are not found.
            ValueError: If the configuration is invalid or incompatible.
        """
        config = load_config(Path(__file__).parents[1] / f"config/{variant}.yaml")
        tts_model = TTSModel._from_pydantic_config_with_weights(
            config, temp, lsd_decode_steps, noise_clamp, eos_threshold
        )
        return tts_model
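
    # Illustrative usage sketch (editor's note, not part of the released file): load the
    # default variant on CPU and inspect it. `load_model` takes neither `self` nor `cls`,
    # so it is meant to be called on the class itself.
    #
    #     from pocket_tts.models.tts_model import TTSModel
    #
    #     model = TTSModel.load_model()      # default variant and sampling parameters
    #     print(model.sample_rate)           # Mimi sample rate from the variant's YAML config
    #     print(model.device)                # "cpu" until the module is moved elsewhere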

    def _run_flow_lm_and_increment_step(
        self,
        model_state: dict,
        text_tokens: torch.Tensor | None = None,
        backbone_input_latents: torch.Tensor | None = None,
        audio_conditioning: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """First one is the backbone output, second one is the audio decoding output."""
        if text_tokens is None:
            text_tokens = torch.zeros((1, 0), dtype=torch.int64, device=self.flow_lm.device)
        if backbone_input_latents is None:
            backbone_input_latents = torch.empty(
                (1, 0, self.flow_lm.ldim), dtype=self.flow_lm.dtype, device=self.flow_lm.device
            )
        if audio_conditioning is None:
            audio_conditioning = torch.empty(
                (1, 0, self.flow_lm.dim), dtype=self.flow_lm.dtype, device=self.flow_lm.device
            )

        output = self._run_flow_lm(
            text_tokens=text_tokens,
            backbone_input_latents=backbone_input_latents,
            model_state=model_state,
            audio_conditioning=audio_conditioning,
        )
        increment_by = (
            text_tokens.shape[1] + backbone_input_latents.shape[1] + audio_conditioning.shape[1]
        )
        increment_steps(self.flow_lm, model_state, increment=increment_by)
        return output

    def _run_flow_lm(
        self,
        model_state: dict,
        text_tokens: torch.Tensor,
        backbone_input_latents: torch.Tensor,
        audio_conditioning: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        text_embeddings = self.flow_lm.conditioner(TokenizedText(text_tokens))
        text_embeddings = torch.cat([text_embeddings, audio_conditioning], dim=1)

        output_embeddings, is_eos = self.flow_lm._sample_next_latent(
            backbone_input_latents,
            text_embeddings,
            model_state=model_state,
            lsd_decode_steps=self.lsd_decode_steps,
            temp=self.temp,
            noise_clamp=self.noise_clamp,
            eos_threshold=self.eos_threshold,
        )
        return output_embeddings[:, None, :], is_eos

    def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
        encoded = self.mimi.encode_to_latent(audio)
        latents = encoded.transpose(-1, -2).to(torch.float32)
        conditioning = F.linear(latents, self.flow_lm.speaker_proj_weight)
        return conditioning

    def _slice_kv_cache(self, model_state: dict, num_frames: int) -> None:
        """Slice KV cache to only keep the first num_frames elements.

        This optimizes memory usage when caching voice states by discarding
        unused cache capacity beyond the actual audio prompt length.

        Args:
            model_state: The model state dict containing KV caches for all modules
            num_frames: Number of frames to keep in the KV cache
        """
        original_size = 0
        sliced_size = 0
        for module_name, module_state in model_state.items():
            if "cache" in module_state:
                # KV cache has shape [2, batch_size, sequence_length, num_heads, dim_per_head]
                cache = module_state["cache"]
                original_size += cache.numel() * cache.element_size()
                # Slice to keep only the first num_frames positions
                module_state["cache"] = cache[:, :, :num_frames, :, :].clone()
                sliced_size += module_state["cache"].numel() * module_state["cache"].element_size()

        memory_saved_mb = (original_size - sliced_size) / (1024 * 1024)
        logger.info(
            f"Sliced KV cache from {original_size / (1024 * 1024):.1f} MB to {sliced_size / (1024 * 1024):.1f} MB "
            f"(saved {memory_saved_mb:.1f} MB)"
        )

    def _expand_kv_cache(self, model_state: dict, sequence_length: int) -> None:
        """Expand KV cache back to full sequence_length for generation.

        When a model state is retrieved from cache with sliced KV caches,
        this method expands them back to the full size needed for generation.

        Args:
            model_state: The model state dict containing potentially sliced KV caches
            sequence_length: Target sequence length to expand caches to
        """
        for module_name, module_state in model_state.items():
            if "cache" in module_state:
                cache = module_state["cache"]
                # KV cache has shape [2, batch_size, current_length, num_heads, dim_per_head]
                current_length = cache.shape[2]
                if current_length < sequence_length:
                    # Create expanded cache filled with NaN for unused positions
                    expanded_cache = torch.full(
                        (
                            cache.shape[0],
                            cache.shape[1],
                            sequence_length,
                            cache.shape[3],
                            cache.shape[4],
                        ),
                        float("NaN"),
                        device=cache.device,
                        dtype=cache.dtype,
                    )
                    # Copy existing data to the beginning
                    expanded_cache[:, :, :current_length, :, :] = cache
                    module_state["cache"] = expanded_cache
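
    # Shape and memory sketch for the two helpers above (editor's illustration, not part of
    # the released file). A cache entry has shape [2, batch, seq, heads, dim_per_head]; the
    # head count and per-head dim below are hypothetical.
    #
    #     cache = torch.zeros((2, 1, 1000, 8, 64), dtype=torch.float32)   # ~3.9 MB
    #     sliced = cache[:, :, :125, :, :].clone()                        # ~0.5 MB kept
    #     # _expand_kv_cache later re-pads the sliced cache to 1000 positions with NaNs,
    #     # copying the 125 real frames back to the front.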

    @torch.no_grad
    def _decode_audio_worker(self, latents_queue: queue.Queue, result_queue: queue.Queue):
        """Worker thread function for decoding audio latents from queue with immediate streaming."""
        try:
            audio_chunks = []
            mimi_state = init_states(self.mimi, batch_size=1, sequence_length=1000)
            while True:
                latent = latents_queue.get()
                if latent is None:
                    break
                mimi_decoding_input = latent * self.flow_lm.emb_std + self.flow_lm.emb_mean
                transposed = mimi_decoding_input.transpose(-1, -2)
                quantized = self.mimi.quantizer(transposed)

                t = time.monotonic()
                audio_frame = self.mimi.decode_from_latent(quantized, mimi_state)
                increment_steps(self.mimi, mimi_state, increment=16)
                audio_frame_duration = audio_frame.shape[2] / self.config.mimi.sample_rate
                # We could log the timings here.
                logger.debug(
                    " " * 30 + "Decoded %d ms of audio with mimi in %d ms",
                    int(audio_frame_duration * 1000),
                    int((time.monotonic() - t) * 1000),
                )
                audio_chunks.append(audio_frame)

                result_queue.put(("chunk", audio_frame))

                latents_queue.task_done()

            # Signal completion
            result_queue.put(("done", None))

        except Exception as e:
            # Put error in result queue
            result_queue.put(("error", e))

    @torch.no_grad
    def generate_audio(
        self,
        model_state: dict,
        text_to_generate: str,
        frames_after_eos: int | None = None,
        copy_state: bool = True,
    ) -> torch.Tensor:
        """Generate complete audio tensor from text input.

        This method generates the full audio output for the given text prompt
        and returns it as a single tensor. It internally uses the streaming
        generation method but collects all chunks before returning.

        This method is NOT thread-safe; separate model instances should be used
        for concurrent generation.

        Args:
            model_state: Model state dictionary containing hidden states and
                positional information. Can be obtained from get_state_for_audio_prompt()
                or init_states(). The state may be modified during generation.
            text_to_generate: Input text to convert to speech. The text will be
                automatically formatted (capitalization, punctuation) for optimal
                generation quality.
            frames_after_eos: Number of additional frames to generate after
                detecting end-of-sequence. If None, automatically determined
                based on text length (1-3 frames).
            copy_state: Whether to create a deep copy of the model state before
                generation. If True, preserves the original state for reuse.
                If False, modifies the input state in-place. Defaults to True.

        Returns:
            torch.Tensor: Generated audio tensor with shape [channels, samples]
                at the model's sample rate (typically 24kHz). The audio is
                normalized and ready for playback or saving.
                You can get the sample rate from the `sample_rate` attribute.

        Raises:
            ValueError: If text_to_generate is empty or invalid.
            RuntimeError: If generation fails due to model errors.
        """
        audio_chunks = []
        for chunk in self.generate_audio_stream(
            model_state=model_state,
            text_to_generate=text_to_generate,
            frames_after_eos=frames_after_eos,
            copy_state=copy_state,
        ):
            audio_chunks.append(chunk)
        return torch.cat(audio_chunks, dim=0)
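
    # Minimal end-to-end sketch (editor's illustration, not part of the released file). The
    # voice file path is hypothetical, and torchaudio is only an example writer, not a
    # dependency of this module. The returned tensor is 1-D here, so a channel dimension is
    # added before saving.
    #
    #     import torchaudio
    #
    #     model = TTSModel.load_model()
    #     state = model.get_state_for_audio_prompt(Path("my_voice.wav"))
    #     audio = model.generate_audio(state, "Hello from pocket TTS.")
    #     torchaudio.save("out.wav", audio.unsqueeze(0), model.sample_rate)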

    @torch.no_grad
    def generate_audio_stream(
        self,
        model_state: dict,
        text_to_generate: str,
        frames_after_eos: int | None = None,
        copy_state: bool = True,
    ):
        """Generate audio streaming chunks from text input.

        This method generates audio from text and yields chunks as they become
        available, enabling real-time playback or processing. It uses multithreading
        to parallelize generation and decoding for optimal performance.
        This method is NOT thread-safe; separate model instances should be used
        for concurrent generation.

        Args:
            model_state: Model state dictionary containing hidden states and
                positional information. Can be obtained from get_state_for_audio_prompt()
                or init_states(). The state may be modified during generation.
            text_to_generate: Input text to convert to speech. The text will be
                automatically formatted (capitalization, punctuation) for optimal
                generation quality.
            frames_after_eos: Number of additional frames to generate after
                detecting end-of-sequence. If None, automatically determined
                based on text length (1-3 frames). Defaults to None.
            copy_state: Whether to create a deep copy of the model state before
                generation. If True, preserves the original state for reuse.
                If False, modifies the input state in-place. Defaults to True.

        Yields:
            torch.Tensor: Audio chunks with shape [samples] at the model's
                sample rate (typically 24kHz). Chunks are yielded as soon as
                they are decoded, enabling real-time streaming.

        Raises:
            ValueError: If text_to_generate is empty or invalid.
            RuntimeError: If generation fails due to model errors or threading issues.

        Note:
            This method uses multithreading to parallelize latent generation
            and audio decoding. Generation performance is logged including
            real-time factor (RTF) metrics.
        """

        # This is a very simplistic way of handling long texts. We could do much better
        # by using teacher forcing, but it would be a bit slower.
        # TODO: add the teacher forcing method for long texts where we use the audio of one chunk
        # as conditioning for the next chunk.
        chunks = split_into_best_sentences(self.flow_lm.conditioner.tokenizer, text_to_generate)

        for chunk in chunks:
            text_to_generate, frames_after_eos_guess = prepare_text_prompt(chunk)
            frames_after_eos_guess += 2
            yield from self._generate_audio_stream_short_text(
                model_state=model_state,
                text_to_generate=chunk,
                frames_after_eos=frames_after_eos_guess,
                copy_state=copy_state,
            )
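
    # Streaming consumption sketch (editor's illustration, not part of the released file).
    # Chunks arrive as 1-D tensors and can be played or written incrementally; sounddevice
    # is a hypothetical consumer, not a dependency of this package.
    #
    #     import sounddevice as sd
    #
    #     model = TTSModel.load_model()
    #     state = model.get_state_for_audio_prompt(Path("my_voice.wav"))
    #     for chunk in model.generate_audio_stream(state, "Streaming, one chunk at a time."):
    #         sd.play(chunk.numpy(), samplerate=model.sample_rate, blocking=True)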

    @torch.no_grad
    def _generate_audio_stream_short_text(
        self, model_state: dict, text_to_generate: str, frames_after_eos: int, copy_state: bool
    ):
        if copy_state:
            model_state = copy.deepcopy(model_state)

        # Expand sliced KV caches back to full size for generation
        self._expand_kv_cache(model_state, sequence_length=1000)

        # Set up multithreaded generation and decoding
        latents_queue = queue.Queue()
        result_queue = queue.Queue()

        # Start decoder worker thread
        decoder_thread = threading.Thread(
            target=self._decode_audio_worker, args=(latents_queue, result_queue), daemon=True
        )
        logger.info("starting timer now!")
        t_generating = time.monotonic()
        decoder_thread.start()

        # Generate latents and add them to queue (decoder processes them in parallel)
        self._generate(
            model_state=model_state,
            text_to_generate=text_to_generate,
            frames_after_eos=frames_after_eos,
            latents_queue=latents_queue,
            result_queue=result_queue,
        )

        # Stream audio chunks as they become available
        total_generated_samples = 0
        while True:
            result = result_queue.get()
            if result[0] == "chunk":
                # Audio chunk available immediately for streaming/playback
                audio_chunk = result[1]
                total_generated_samples += audio_chunk.shape[-1]
                yield audio_chunk[0, 0]  # Remove batch, channel
            elif result[0] == "done":
                # Generation complete
                break
            elif result[0] == "error":
                # Wait for decoder thread to finish cleanly before propagating error
                with display_execution_time("Waiting for mimi decoder to finish"):
                    decoder_thread.join()
                # Propagate error
                raise result[1]

        # Wait for decoder thread to finish cleanly
        with display_execution_time("Waiting for mimi decoder to finish"):
            decoder_thread.join()

        # Print timing information
        duration_generated_audio = int(
            total_generated_samples * 1000 / self.config.mimi.sample_rate
        )
        generation_time = int((time.monotonic() - t_generating) * 1000)
        real_time_factor = duration_generated_audio / generation_time

        logger.info(
            "Generated: %d ms of audio in %d ms so %.2fx faster than real-time",
            duration_generated_audio,
            generation_time,
            real_time_factor,
        )
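
    # Queue protocol used by the two worker threads above (editor's summary, not part of the
    # released file):
    #
    #     latents_queue <- torch.Tensor latents, then None as a stop sentinel
    #     result_queue  <- ("chunk", audio_frame)  one decoded frame, ready to stream
    #                      ("done", None)          decoder finished cleanly
    #                      ("error", exception)    either thread failed; re-raised in the caller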

    @torch.no_grad
    def _generate(
        self,
        model_state: dict,
        text_to_generate: str,
        frames_after_eos: int,
        latents_queue: queue.Queue,
        result_queue: queue.Queue,
    ):
        gen_len_sec = len(text_to_generate.split()) * 1 + 2.0
        max_gen_len = int(gen_len_sec * 12.5)
        prepared = self.flow_lm.conditioner.prepare(text_to_generate)

        with display_execution_time("Prompting text"):
            self._run_flow_lm_and_increment_step(
                model_state=model_state, text_tokens=prepared.tokens
            )

        def run_generation():
            try:
                self._autoregressive_generation(
                    model_state, max_gen_len, frames_after_eos, latents_queue
                )
            except Exception as e:
                logger.error(f"Error in autoregressive generation: {e}")
                # Signal decoder to stop by putting None (completion sentinel)
                if latents_queue is not None:
                    latents_queue.put(None)
                # Report error to main thread
                if result_queue is not None:
                    result_queue.put(("error", e))

        generation_thread = threading.Thread(target=run_generation, daemon=True)
        generation_thread.start()

    @torch.no_grad
    def _autoregressive_generation(
        self, model_state: dict, max_gen_len: int, frames_after_eos: int, latents_queue: queue.Queue
    ):
        backbone_input = torch.full(
            (1, 1, self.flow_lm.ldim),
            fill_value=float("NaN"),
            device=next(iter(self.flow_lm.parameters())).device,
            dtype=self.flow_lm.dtype,
        )
        steps_times = []
        eos_step = None
        for generation_step in range(max_gen_len):
            with display_execution_time("Generating latent", print_output=False) as timer:
                next_latent, is_eos = self._run_flow_lm_and_increment_step(
                    model_state=model_state, backbone_input_latents=backbone_input
                )
            if is_eos.item() and eos_step is None:
                eos_step = generation_step
            if eos_step is not None and generation_step >= eos_step + frames_after_eos:
                break

            # Add generated latent to queue for immediate decoding
            latents_queue.put(next_latent)
            backbone_input = next_latent
            steps_times.append(timer.elapsed_time_ms)
        else:
            if os.environ.get("KPOCKET_TTS_ERROR_WITHOUT_EOS", "0") == "1":
                raise RuntimeError("Generation reached maximum length without EOS!")
            logger.warning(
                "Maximum generation length reached without EOS, this very often indicates an error."
            )

        # Add sentinel value to signal end of generation
        latents_queue.put(None)
        logger.info("Average generation step time: %d ms", int(statistics.mean(steps_times)))

    @lru_cache(maxsize=2)
    def _cached_get_state_for_audio_prompt(
        self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
    ) -> dict:
        return self.get_state_for_audio_prompt(audio_conditioning, truncate)

    @torch.no_grad
    def get_state_for_audio_prompt(
        self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
    ) -> dict:
        """Create model state conditioned on audio prompt for continuation.

        This method processes an audio prompt and creates a model state that
        captures the acoustic characteristics (speaker voice, style, prosody)
        for use in subsequent text-to-speech generation. The resulting state
        enables voice cloning and audio continuation with speaker consistency.

        Args:
            audio_conditioning: Audio prompt to condition on. Can be:
                - Path: Local file path to audio file
                - str: URL to download audio file from
                - torch.Tensor: Pre-loaded audio tensor with shape [channels, samples]
            truncate: Whether to truncate long audio prompts to 30 seconds.
                Helps prevent memory issues with very long inputs. Defaults to False.

        Returns:
            dict: Model state dictionary containing hidden states and positional
                information conditioned on the audio prompt. This state can be
                passed to `generate_audio()` or `generate_audio_stream()` for
                voice-consistent generation.

        Raises:
            FileNotFoundError: If audio file path doesn't exist.
            ValueError: If audio tensor is invalid or empty.
            RuntimeError: If audio processing or encoding fails.

        Note:
            - Audio is automatically resampled to the model's sample rate (24kHz)
            - The audio is encoded using the Mimi compression model and projected
              to the flow model's latent space
            - Processing time is logged for performance monitoring
            - The state preserves speaker characteristics for voice cloning
        """
        if isinstance(audio_conditioning, str) and audio_conditioning in PREDEFINED_VOICES:
            # We get the audio conditioning directly from the safetensors file.
            prompt = load_predefined_voice(audio_conditioning)
        else:
            if not self.has_voice_cloning and isinstance(audio_conditioning, (str, Path)):
                raise ValueError(
                    f"We could not download the weights for the model with voice cloning, "
                    f"but you're trying to use voice cloning. "
                    f"Without voice cloning, you can use our catalog of voices {list(PREDEFINED_VOICES)}. "
                    f"If you want access to the model with voice cloning, go to "
                    f"https://huggingface.co/kyutai/pocket-tts and accept the terms, "
                    f"then make sure you're logged in locally with `uvx hf auth login`."
                )
            if isinstance(audio_conditioning, str):
                audio_conditioning = download_if_necessary(audio_conditioning)

            if isinstance(audio_conditioning, Path):
                audio, conditioning_sample_rate = audio_read(audio_conditioning)

                if truncate:
                    max_samples = int(30 * conditioning_sample_rate)  # 30 seconds of audio
                    if audio.shape[-1] > max_samples:
                        audio = audio[..., :max_samples]
                        logger.info(f"Audio truncated to first 30 seconds ({max_samples} samples)")

                audio_conditioning = convert_audio(
                    audio, conditioning_sample_rate, self.config.mimi.sample_rate, 1
                )

            with display_execution_time("Encoding audio prompt"):
                prompt = self._encode_audio(audio_conditioning.unsqueeze(0).to(self.device))
            # import safetensors.torch
            # safetensors.torch.save_file(
            #     {"audio_prompt": prompt},
            #     "/projects/huggingface/pocket-tts/embeddings/cosette.safetensors"
            # )

        model_state = init_states(self.flow_lm, batch_size=1, sequence_length=1000)

        with display_execution_time("Prompting audio"):
            self._run_flow_lm_and_increment_step(model_state=model_state, audio_conditioning=prompt)

        # Optimize memory by slicing KV cache to only keep frames from the audio prompt
        num_audio_frames = prompt.shape[1]
        self._slice_kv_cache(model_state, num_audio_frames)

        return model_state
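
    # Voice-prompt reuse sketch (editor's illustration, not part of the released file).
    # Because generate_audio() deep-copies the state by default (copy_state=True), one prompt
    # state can serve several generations; "alba" stands in for whatever keys
    # PREDEFINED_VOICES actually contains.
    #
    #     model = TTSModel.load_model()
    #     state = model.get_state_for_audio_prompt("alba")   # predefined voice, no cloning weights needed
    #     first = model.generate_audio(state, "First sentence.")
    #     second = model.generate_audio(state, "Second sentence, same voice.")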


def prepare_text_prompt(text: str) -> tuple[str, int]:
    text = text.strip()
    if text == "":
        raise ValueError("Text prompt cannot be empty")
    text = text.replace("\n", " ").replace("\r", " ").replace("  ", " ")
    number_of_words = len(text.split())
    if number_of_words <= 4:
        frames_after_eos_guess = 3
    else:
        frames_after_eos_guess = 1

    # Make sure it starts with an uppercase letter
    if not text[0].isupper():
        text = text[0].upper() + text[1:]

    # Let's make sure it ends with some kind of punctuation
    # If it ends with a letter or digit, we add a period.
    if text[-1].isalnum():
        text = text + "."

    # The model does not perform well when there are very few tokens, so
    # we can add empty spaces at the beginning to increase the token count.
    if len(text.split()) < 5:
        text = " " * 8 + text

    return text, frames_after_eos_guess
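
# Behaviour sketch for prepare_text_prompt (editor's note, not part of the released file);
# the outputs below follow directly from the rules above.
#
#     prepare_text_prompt("hello there")
#     # -> ("        Hello there.", 3)   capitalized, period appended, left-padded (< 5 words)
#     prepare_text_prompt("This sentence already has plenty of words in it")
#     # -> ("This sentence already has plenty of words in it.", 1)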


def split_into_best_sentences(tokenizer, text_to_generate: str) -> list[str]:
    text_to_generate, _ = prepare_text_prompt(text_to_generate)
    text_to_generate = text_to_generate.strip()
    tokens = tokenizer(text_to_generate)
    list_of_tokens = tokens.tokens[0].tolist()

    _, *end_of_sentence_tokens = tokenizer(".!...?").tokens[0].tolist()

    end_of_sentences_indices = [0]
    previous_was_end_of_sentence_token = False

    for token_idx, token in enumerate(list_of_tokens):
        if token in end_of_sentence_tokens:
            previous_was_end_of_sentence_token = True
        else:
            if previous_was_end_of_sentence_token:
                end_of_sentences_indices.append(token_idx)
            previous_was_end_of_sentence_token = False
    end_of_sentences_indices.append(len(list_of_tokens))

    nb_tokens_and_sentences = []
    for i in range(len(end_of_sentences_indices) - 1):
        # let's print
        start = end_of_sentences_indices[i]
        end = end_of_sentences_indices[i + 1]
        text = tokenizer.sp.decode(list_of_tokens[start:end])
        nb_tokens_and_sentences.append((end - start, text))

    max_nb_tokens_in_a_chunk = 50
    chunks = []
    current_chunk = ""
    current_nb_of_tokens_in_chunk = 0
    for nb_tokens, sentence in nb_tokens_and_sentences:
        if current_chunk == "":
            current_chunk = sentence
            current_nb_of_tokens_in_chunk = nb_tokens
            continue

        if current_nb_of_tokens_in_chunk + nb_tokens > max_nb_tokens_in_a_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_nb_of_tokens_in_chunk = nb_tokens
        else:
            current_chunk += " " + sentence
            current_nb_of_tokens_in_chunk += nb_tokens

    if current_chunk != "":
        chunks.append(current_chunk.strip())

    return chunks
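
# Chunking sketch (editor's illustration, not part of the released file). The tokenizer is
# taken from the loaded model; split_into_best_sentences() packs whole sentences into chunks
# of at most 50 tokens, and generate_audio_stream() synthesizes each chunk in turn.
#
#     model = TTSModel.load_model()
#     tokenizer = model.flow_lm.conditioner.tokenizer
#     long_text = "First sentence. Second one! And a third, much longer sentence..."
#     for chunk in split_into_best_sentences(tokenizer, long_text):
#         print(chunk)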