pocket-tts 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,782 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import queue
5
+ import statistics
6
+ import threading
7
+ import time
8
+ from functools import lru_cache
9
+ from pathlib import Path
10
+
11
+ import safetensors
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from typing_extensions import Self
16
+
17
+ from pocket_tts.conditioners.base import TokenizedText
18
+ from pocket_tts.data.audio import audio_read
19
+ from pocket_tts.data.audio_utils import convert_audio
20
+ from pocket_tts.default_parameters import (
21
+ DEFAULT_EOS_THRESHOLD,
22
+ DEFAULT_LSD_DECODE_STEPS,
23
+ DEFAULT_NOISE_CLAMP,
24
+ DEFAULT_TEMPERATURE,
25
+ DEFAULT_VARIANT,
26
+ )
27
+ from pocket_tts.models.flow_lm import FlowLMModel
28
+ from pocket_tts.models.mimi import MimiModel
29
+ from pocket_tts.modules import mimi_transformer
30
+ from pocket_tts.modules.dummy_quantizer import DummyQuantizer
31
+ from pocket_tts.modules.seanet import SEANetDecoder, SEANetEncoder
32
+ from pocket_tts.modules.stateful_module import increment_steps, init_states
33
+ from pocket_tts.utils.config import Config, load_config
34
+ from pocket_tts.utils.utils import (
35
+ PREDEFINED_VOICES,
36
+ display_execution_time,
37
+ download_if_necessary,
38
+ load_predefined_voice,
39
+ size_of_dict,
40
+ )
41
+ from pocket_tts.utils.weights_loading import get_flow_lm_state_dict, get_mimi_state_dict
42
+
43
+ torch.set_num_threads(1)
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ class TTSModel(nn.Module):
48
+ def __init__(
49
+ self,
50
+ flow_lm: FlowLMModel,
51
+ temp: float,
52
+ lsd_decode_steps: int,
53
+ noise_clamp: float | None,
54
+ eos_threshold: float,
55
+ config: Config,
56
+ ):
57
+ super().__init__()
58
+ self.flow_lm = flow_lm
59
+ self.temp = temp
60
+ self.lsd_decode_steps = lsd_decode_steps
61
+ self.noise_clamp = noise_clamp
62
+ self.eos_threshold = eos_threshold
63
+ self.config = config
64
+ self.has_voice_cloning = True
65
+
66
+ @property
67
+ def device(self) -> str:
68
+ return next(self.parameters()).device.type
69
+
70
+ @property
71
+ def sample_rate(self) -> int:
72
+ return self.config.mimi.sample_rate
73
+
74
+ @classmethod
75
+ def _from_pydantic_config(
76
+ cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
77
+ ) -> Self:
78
+ flow_lm = FlowLMModel.from_pydantic_config(
79
+ config.flow_lm, latent_dim=config.mimi.quantizer.dimension
80
+ )
81
+ tts_model = cls(flow_lm, temp, lsd_decode_steps, noise_clamp, eos_threshold, config)
82
+ return tts_model
83
+
84
+ @classmethod
85
+ def _from_pydantic_config_with_weights(
86
+ cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
87
+ ) -> Self:
88
+ tts_model = cls._from_pydantic_config(
89
+ config, temp, lsd_decode_steps, noise_clamp, eos_threshold
90
+ )
91
+ tts_model.flow_lm.speaker_proj_weight = torch.nn.Parameter(
92
+ torch.zeros((1024, 512), dtype=torch.float32)
93
+ )
94
+ if config.flow_lm.weights_path is not None:
95
+ if config.mimi.weights_path is None:
96
+ raise ValueError(
97
+ "If you specify flow_lm.weights_path you should specify mimi.weights_path"
98
+ )
99
+ logger.info(f"Loading FlowLM weights from {config.flow_lm.weights_path}")
100
+ state_dict_flowlm = get_flow_lm_state_dict(
101
+ download_if_necessary(config.flow_lm.weights_path)
102
+ )
103
+ tts_model.flow_lm.load_state_dict(state_dict_flowlm, strict=True)
104
+
105
+ # safetensors.torch.save_file(tts_model.state_dict(), "7442637a.safetensors")
106
+ # Create mimi config directly from the provided config using model_dump
107
+ mimi_config = config.mimi.model_dump()
108
+
109
+ # Build mimi model from config
110
+ encoder = SEANetEncoder(**mimi_config["seanet"])
111
+ decoder = SEANetDecoder(**mimi_config["seanet"])
112
+
113
+ encoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
114
+ decoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
115
+ quantizer = DummyQuantizer(**mimi_config["quantizer"])
116
+
117
+ tts_model.mimi = MimiModel(
118
+ encoder,
119
+ decoder,
120
+ quantizer,
121
+ channels=mimi_config["channels"],
122
+ sample_rate=mimi_config["sample_rate"],
123
+ frame_rate=mimi_config["frame_rate"],
124
+ encoder_frame_rate=mimi_config["sample_rate"] / encoder.hop_length,
125
+ encoder_transformer=encoder_transformer,
126
+ decoder_transformer=decoder_transformer,
127
+ ).to(device="cpu")
128
+
129
+ # Load mimi weights from the config safetensors file with complete mapping for strict loading
130
+
131
+ if config.mimi.weights_path is not None:
132
+ if config.flow_lm.weights_path is None:
133
+ raise ValueError(
134
+ "If you specify mimi.weights_path you should specify flow_lm.weights_path"
135
+ )
136
+ logger.info(f"Loading Mimi weights from {config.mimi.weights_path}")
137
+ mimi_state = get_mimi_state_dict(download_if_necessary(config.mimi.weights_path))
138
+ tts_model.mimi.load_state_dict(mimi_state, strict=True)
139
+
140
+ tts_model.mimi.eval()
141
+ # tts_model.to(dtype=torch.float32)
142
+
143
+ # uncomment to save the weights
144
+ # tts_model = tts_model.to(dtype=torch.bfloat16)
145
+ # safetensors.torch.save_file(tts_model.state_dict(), "tts_b6369a24.safetensors")
146
+ if config.weights_path is not None:
147
+ logger.info(f"Loading TTSModel weights from {config.weights_path}")
148
+ try:
149
+ weights_file = download_if_necessary(config.weights_path)
150
+ except Exception:
151
+ tts_model.has_voice_cloning = False
152
+ weights_file = download_if_necessary(config.weights_path_without_voice_cloning)
153
+
154
+ state_dict = safetensors.torch.load_file(weights_file)
155
+ tts_model.load_state_dict(state_dict, strict=True)
156
+
157
+ if config.flow_lm.weights_path is None and config.weights_path is None:
158
+ logger.warning(
159
+ "No weights_path specified for FlowLM or TTSModel, model is uninitialized!"
160
+ )
161
+ size_in_mb = size_of_dict(tts_model.state_dict()) // 1e6
162
+ logger.info(f"TTS Model loaded successfully. Its size is {size_in_mb} MB")
163
+
164
+ return tts_model
165
+
166
+ @classmethod
+ def load_model(
+ cls,
167
+ variant: str = DEFAULT_VARIANT,
168
+ temp: float | int = DEFAULT_TEMPERATURE,
169
+ lsd_decode_steps: int = DEFAULT_LSD_DECODE_STEPS,
170
+ noise_clamp: float | int | None = DEFAULT_NOISE_CLAMP,
171
+ eos_threshold: float = DEFAULT_EOS_THRESHOLD,
172
+ ) -> Self:
173
+ """Load a pre-trained TTS model with specified configuration.
174
+
175
+ This class method loads a complete TTS model including the flow language model
176
+ and Mimi compression model from pre-trained weights. The model is initialized
177
+ with the specified generation parameters and ready for inference.
178
+
179
+ Args:
180
+ variant: Model variant identifier corresponding to a config file name
181
+ (e.g., '610b0b2c'). Must match a YAML file in the config directory.
182
+ temp: Sampling temperature for generation. Higher values produce more
183
+ diverse but potentially lower quality output.
184
+ lsd_decode_steps: Number of steps for Lagrangian Self Distillation
185
+ decoding. More steps can improve quality but increase computation.
186
+ noise_clamp: Maximum value for noise sampling. If None, no clamping
187
+ is applied. Helps prevent extreme values in generation.
188
+ eos_threshold: Threshold for end-of-sequence detection. Higher values
189
+ make the model more likely to continue generating.
190
+
191
+ Returns:
192
+ TTSModel: Fully initialized model with loaded weights on CPU, ready for
193
+ text-to-speech generation.
194
+
195
+ Raises:
196
+ FileNotFoundError: If the specified config file or model weights
197
+ are not found.
198
+ ValueError: If the configuration is invalid or incompatible.
199
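+
+ Example:
+ A minimal, illustrative sketch (assumes the default variant's
+ config and weights can be downloaded or are already cached, and
+ picks an arbitrary voice from the predefined catalog):
+
+ >>> model = TTSModel.load_model()
+ >>> state = model.get_state_for_audio_prompt(next(iter(PREDEFINED_VOICES)))
+ >>> audio = model.generate_audio(state, "Hello from pocket TTS.")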
+ """
200
+ config = load_config(Path(__file__).parents[1] / f"config/{variant}.yaml")
201
+ tts_model = cls._from_pydantic_config_with_weights(
202
+ config, temp, lsd_decode_steps, noise_clamp, eos_threshold
203
+ )
204
+ return tts_model
205
+
206
+ def _run_flow_lm_and_increment_step(
207
+ self,
208
+ model_state: dict,
209
+ text_tokens: torch.Tensor | None = None,
210
+ backbone_input_latents: torch.Tensor | None = None,
211
+ audio_conditioning: torch.Tensor | None = None,
212
+ ) -> tuple[torch.Tensor, torch.Tensor]:
213
+ """First one is the backbone output, second one is the audio decoding output."""
214
+ if text_tokens is None:
215
+ text_tokens = torch.zeros((1, 0), dtype=torch.int64, device=self.flow_lm.device)
216
+ if backbone_input_latents is None:
217
+ backbone_input_latents = torch.empty(
218
+ (1, 0, self.flow_lm.ldim), dtype=self.flow_lm.dtype, device=self.flow_lm.device
219
+ )
220
+ if audio_conditioning is None:
221
+ audio_conditioning = torch.empty(
222
+ (1, 0, self.flow_lm.dim), dtype=self.flow_lm.dtype, device=self.flow_lm.device
223
+ )
224
+
225
+ output = self._run_flow_lm(
226
+ text_tokens=text_tokens,
227
+ backbone_input_latents=backbone_input_latents,
228
+ model_state=model_state,
229
+ audio_conditioning=audio_conditioning,
230
+ )
231
+ increment_by = (
232
+ text_tokens.shape[1] + backbone_input_latents.shape[1] + audio_conditioning.shape[1]
233
+ )
234
+ increment_steps(self.flow_lm, model_state, increment=increment_by)
235
+ return output
236
+
237
+ def _run_flow_lm(
238
+ self,
239
+ model_state: dict,
240
+ text_tokens: torch.Tensor,
241
+ backbone_input_latents: torch.Tensor,
242
+ audio_conditioning: torch.Tensor,
243
+ ) -> tuple[torch.Tensor, torch.Tensor]:
244
+ text_embeddings = self.flow_lm.conditioner(TokenizedText(text_tokens))
245
+ text_embeddings = torch.cat([text_embeddings, audio_conditioning], dim=1)
246
+
247
+ output_embeddings, is_eos = self.flow_lm._sample_next_latent(
248
+ backbone_input_latents,
249
+ text_embeddings,
250
+ model_state=model_state,
251
+ lsd_decode_steps=self.lsd_decode_steps,
252
+ temp=self.temp,
253
+ noise_clamp=self.noise_clamp,
254
+ eos_threshold=self.eos_threshold,
255
+ )
256
+ return output_embeddings[:, None, :], is_eos
257
+
258
+ def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
259
+ encoded = self.mimi.encode_to_latent(audio)
260
+ latents = encoded.transpose(-1, -2).to(torch.float32)
261
+ conditioning = F.linear(latents, self.flow_lm.speaker_proj_weight)
262
+ return conditioning
263
+
264
+ def _slice_kv_cache(self, model_state: dict, num_frames: int) -> None:
265
+ """Slice KV cache to only keep the first num_frames elements.
266
+
267
+ This optimizes memory usage when caching voice states by discarding
268
+ unused cache capacity beyond the actual audio prompt length.
269
+
270
+ Args:
271
+ model_state: The model state dict containing KV caches for all modules
272
+ num_frames: Number of frames to keep in the KV cache
273
+ """
274
+ original_size = 0
275
+ sliced_size = 0
276
+ for module_name, module_state in model_state.items():
277
+ if "cache" in module_state:
278
+ # KV cache has shape [2, batch_size, sequence_length, num_heads, dim_per_head]
279
+ cache = module_state["cache"]
280
+ original_size += cache.numel() * cache.element_size()
281
+ # Slice to keep only the first num_frames positions
282
+ module_state["cache"] = cache[:, :, :num_frames, :, :].clone()
283
+ sliced_size += module_state["cache"].numel() * module_state["cache"].element_size()
284
+
285
+ memory_saved_mb = (original_size - sliced_size) / (1024 * 1024)
286
+ logger.info(
287
+ f"Sliced KV cache from {original_size / (1024 * 1024):.1f} MB to {sliced_size / (1024 * 1024):.1f} MB "
288
+ f"(saved {memory_saved_mb:.1f} MB)"
289
+ )
290
+
291
+ def _expand_kv_cache(self, model_state: dict, sequence_length: int) -> None:
292
+ """Expand KV cache back to full sequence_length for generation.
293
+
294
+ When a model state is retrieved from cache with sliced KV caches,
295
+ this method expands them back to the full size needed for generation.
296
+
297
+ Args:
298
+ model_state: The model state dict containing potentially sliced KV caches
299
+ sequence_length: Target sequence length to expand caches to
300
+ """
301
+ for module_name, module_state in model_state.items():
302
+ if "cache" in module_state:
303
+ cache = module_state["cache"]
304
+ # KV cache has shape [2, batch_size, current_length, num_heads, dim_per_head]
305
+ current_length = cache.shape[2]
306
+ if current_length < sequence_length:
307
+ # Create expanded cache filled with NaN for unused positions
308
+ expanded_cache = torch.full(
309
+ (
310
+ cache.shape[0],
311
+ cache.shape[1],
312
+ sequence_length,
313
+ cache.shape[3],
314
+ cache.shape[4],
315
+ ),
316
+ float("NaN"),
317
+ device=cache.device,
318
+ dtype=cache.dtype,
319
+ )
320
+ # Copy existing data to the beginning
321
+ expanded_cache[:, :, :current_length, :, :] = cache
322
+ module_state["cache"] = expanded_cache
323
+
324
+ @torch.no_grad
325
+ def _decode_audio_worker(self, latents_queue: queue.Queue, result_queue: queue.Queue):
326
+ """Worker thread function for decoding audio latents from queue with immediate streaming."""
327
+ try:
328
+ audio_chunks = []
329
+ mimi_state = init_states(self.mimi, batch_size=1, sequence_length=1000)
330
+ while True:
331
+ latent = latents_queue.get()
332
+ if latent is None:
333
+ break
334
+ mimi_decoding_input = latent * self.flow_lm.emb_std + self.flow_lm.emb_mean
335
+ transposed = mimi_decoding_input.transpose(-1, -2)
336
+ quantized = self.mimi.quantizer(transposed)
337
+
338
+ t = time.monotonic()
339
+ audio_frame = self.mimi.decode_from_latent(quantized, mimi_state)
340
+ increment_steps(self.mimi, mimi_state, increment=16)
341
+ audio_frame_duration = audio_frame.shape[2] / self.config.mimi.sample_rate
342
+ # Log how long mimi decoding took for this frame relative to the audio produced.
343
+ logger.debug(
344
+ " " * 30 + "Decoded %d ms of audio with mimi in %d ms",
345
+ int(audio_frame_duration * 1000),
346
+ int((time.monotonic() - t) * 1000),
347
+ )
348
+ audio_chunks.append(audio_frame)
349
+
350
+ result_queue.put(("chunk", audio_frame))
351
+
352
+ latents_queue.task_done()
353
+
354
+ # Signal completion
355
+ result_queue.put(("done", None))
356
+
357
+ except Exception as e:
358
+ # Put error in result queue
359
+ result_queue.put(("error", e))
360
+
361
+ @torch.no_grad
362
+ def generate_audio(
363
+ self,
364
+ model_state: dict,
365
+ text_to_generate: str,
366
+ frames_after_eos: int | None = None,
367
+ copy_state: bool = True,
368
+ ) -> torch.Tensor:
369
+ """Generate complete audio tensor from text input.
370
+
371
+ This method generates the full audio output for the given text prompt
372
+ and returns it as a single tensor. It internally uses the streaming
373
+ generation method but collects all chunks before returning.
374
+
375
+ This method is NOT thread-safe; separate model instances should be used
376
+ for concurrent generation.
377
+
378
+ Args:
379
+ model_state: Model state dictionary containing hidden states and
380
+ positional information. Can be obtained from get_state_for_audio_prompt()
381
+ or init_states(). The state may be modified during generation.
382
+ text_to_generate: Input text to convert to speech. The text will be
383
+ automatically formatted (capitalization, punctuation) for optimal
384
+ generation quality.
385
+ frames_after_eos: Number of additional frames to generate after
386
+ detecting end-of-sequence. If None, automatically determined
387
+ based on text length (1-3 frames).
388
+ copy_state: Whether to create a deep copy of the model state before
389
+ generation. If True, preserves the original state for reuse.
390
+ If False, modifies the input state in-place. Defaults to True.
391
+
392
+ Returns:
393
+ torch.Tensor: Generated mono audio tensor with shape [samples]
394
+ at the model's sample rate (typically 24kHz). The audio is
395
+ normalized and ready for playback or saving.
396
+ You can get the sample rate from the `sample_rate` attribute.
397
+
398
+ Raises:
399
+ ValueError: If text_to_generate is empty or invalid.
400
+ RuntimeError: If generation fails due to model errors.
401
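+
+ Example:
+ Illustrative sketch; the predefined voice is picked arbitrarily
+ from the catalog shipped with the package:
+
+ >>> model = TTSModel.load_model()
+ >>> state = model.get_state_for_audio_prompt(next(iter(PREDEFINED_VOICES)))
+ >>> audio = model.generate_audio(state, "A complete sentence to synthesize.")
+ >>> duration_sec = audio.shape[-1] / model.sample_rate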
+ """
402
+ audio_chunks = []
403
+ for chunk in self.generate_audio_stream(
404
+ model_state=model_state,
405
+ text_to_generate=text_to_generate,
406
+ frames_after_eos=frames_after_eos,
407
+ copy_state=copy_state,
408
+ ):
409
+ audio_chunks.append(chunk)
410
+ return torch.cat(audio_chunks, dim=0)
411
+
412
+ @torch.no_grad
413
+ def generate_audio_stream(
414
+ self,
415
+ model_state: dict,
416
+ text_to_generate: str,
417
+ frames_after_eos: int | None = None,
418
+ copy_state: bool = True,
419
+ ):
420
+ """Generate audio streaming chunks from text input.
421
+
422
+ This method generates audio from text and yields chunks as they become
423
+ available, enabling real-time playback or processing. It uses multithreading
424
+ to parallelize generation and decoding for optimal performance.
425
+ This method is NOT thread-safe; separate model instances should be used
426
+ for concurrent generation.
427
+
428
+ Args:
429
+ model_state: Model state dictionary containing hidden states and
430
+ positional information. Can be obtained from get_state_for_audio_prompt()
431
+ or init_states(). The state may be modified during generation.
432
+ text_to_generate: Input text to convert to speech. The text will be
433
+ automatically formatted (capitalization, punctuation) for optimal
434
+ generation quality.
435
+ frames_after_eos: Number of additional frames to generate after
436
+ detecting end-of-sequence. If None, automatically determined
437
+ based on text length (1-3 frames). Defaults to None.
438
+ copy_state: Whether to create a deep copy of the model state before
439
+ generation. If True, preserves the original state for reuse.
440
+ If False, modifies the input state in-place. Defaults to True.
441
+
442
+ Yields:
443
+ torch.Tensor: Audio chunks with shape [samples] at the model's
444
+ sample rate (typically 24kHz). Chunks are yielded as soon as
445
+ they are decoded, enabling real-time streaming.
446
+
447
+ Raises:
448
+ ValueError: If text_to_generate is empty or invalid.
449
+ RuntimeError: If generation fails due to model errors or threading issues.
450
+
451
+ Note:
452
+ This method uses multithreading to parallelize latent generation
453
+ and audio decoding. Generation performance is logged including
454
+ real-time factor (RTF) metrics.
455
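+
+ Example:
+ A rough streaming sketch; what you do with each chunk (playback,
+ websocket, file append) is left to the caller:
+
+ >>> model = TTSModel.load_model()
+ >>> state = model.get_state_for_audio_prompt(next(iter(PREDEFINED_VOICES)))
+ >>> for chunk in model.generate_audio_stream(state, "Streaming example."):
+ ... pass # chunk is a 1-D tensor at model.sample_rate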
+ """
456
+
457
+ # This is a very simplistic way of handling long texts. We could do much better
458
+ # by using teacher forcing, but it would be a bit slower.
459
+ # TODO: add the teacher forcing method for long texts where we use the audio of one chunk
460
+ # as conditioning for the next chunk.
461
+ chunks = split_into_best_sentences(self.flow_lm.conditioner.tokenizer, text_to_generate)
462
+
463
+ for chunk in chunks:
464
+ text_to_generate, frames_after_eos_guess = prepare_text_prompt(chunk)
465
+ frames_after_eos_guess += 2
466
+ yield from self._generate_audio_stream_short_text(
467
+ model_state=model_state,
468
+ text_to_generate=text_to_generate,
469
+ frames_after_eos=frames_after_eos_guess,
470
+ copy_state=copy_state,
471
+ )
472
+
473
+ @torch.no_grad
474
+ def _generate_audio_stream_short_text(
475
+ self, model_state: dict, text_to_generate: str, frames_after_eos: int, copy_state: bool
476
+ ):
477
+ if copy_state:
478
+ model_state = copy.deepcopy(model_state)
479
+
480
+ # Expand sliced KV caches back to full size for generation
481
+ self._expand_kv_cache(model_state, sequence_length=1000)
482
+
483
+ # Set up multithreaded generation and decoding
484
+ latents_queue = queue.Queue()
485
+ result_queue = queue.Queue()
486
+
487
+ # Start decoder worker thread
488
+ decoder_thread = threading.Thread(
489
+ target=self._decode_audio_worker, args=(latents_queue, result_queue), daemon=True
490
+ )
491
+ logger.info("starting timer now!")
492
+ t_generating = time.monotonic()
493
+ decoder_thread.start()
494
+
495
+ # Generate latents and add them to queue (decoder processes them in parallel)
496
+ self._generate(
497
+ model_state=model_state,
498
+ text_to_generate=text_to_generate,
499
+ frames_after_eos=frames_after_eos,
500
+ latents_queue=latents_queue,
501
+ result_queue=result_queue,
502
+ )
503
+
504
+ # Stream audio chunks as they become available
505
+ total_generated_samples = 0
506
+ while True:
507
+ result = result_queue.get()
508
+ if result[0] == "chunk":
509
+ # Audio chunk available immediately for streaming/playback
510
+ audio_chunk = result[1]
511
+ total_generated_samples += audio_chunk.shape[-1]
512
+ yield audio_chunk[0, 0] # Remove batch, channel
513
+ elif result[0] == "done":
514
+ # Generation complete
515
+ break
516
+ elif result[0] == "error":
517
+ # Wait for decoder thread to finish cleanly before propagating error
518
+ with display_execution_time("Waiting for mimi decoder to finish"):
519
+ decoder_thread.join()
520
+ # Propagate error
521
+ raise result[1]
522
+
523
+ # Wait for decoder thread to finish cleanly
524
+ with display_execution_time("Waiting for mimi decoder to finish"):
525
+ decoder_thread.join()
526
+
527
+ # Print timing information
528
+ duration_generated_audio = int(
529
+ total_generated_samples * 1000 / self.config.mimi.sample_rate
530
+ )
531
+ generation_time = int((time.monotonic() - t_generating) * 1000)
532
+ real_time_factor = duration_generated_audio / generation_time
533
+
534
+ logger.info(
535
+ "Generated: %d ms of audio in %d ms so %.2fx faster than real-time",
536
+ duration_generated_audio,
537
+ generation_time,
538
+ real_time_factor,
539
+ )
540
+
541
+ @torch.no_grad
542
+ def _generate(
543
+ self,
544
+ model_state: dict,
545
+ text_to_generate: str,
546
+ frames_after_eos: int,
547
+ latents_queue: queue.Queue,
548
+ result_queue: queue.Queue,
549
+ ):
550
+ gen_len_sec = len(text_to_generate.split()) * 1 + 2.0
551
+ max_gen_len = int(gen_len_sec * 12.5)
552
+ prepared = self.flow_lm.conditioner.prepare(text_to_generate)
553
+
554
+ with display_execution_time("Prompting text"):
555
+ self._run_flow_lm_and_increment_step(
556
+ model_state=model_state, text_tokens=prepared.tokens
557
+ )
558
+
559
+ def run_generation():
560
+ try:
561
+ self._autoregressive_generation(
562
+ model_state, max_gen_len, frames_after_eos, latents_queue
563
+ )
564
+ except Exception as e:
565
+ logger.error(f"Error in autoregressive generation: {e}")
566
+ # Signal decoder to stop by putting None (completion sentinel)
567
+ if latents_queue is not None:
568
+ latents_queue.put(None)
569
+ # Report error to main thread
570
+ if result_queue is not None:
571
+ result_queue.put(("error", e))
572
+
573
+ generation_thread = threading.Thread(target=run_generation, daemon=True)
574
+ generation_thread.start()
575
+
576
+ @torch.no_grad
577
+ def _autoregressive_generation(
578
+ self, model_state: dict, max_gen_len: int, frames_after_eos: int, latents_queue: queue.Queue
579
+ ):
580
+ backbone_input = torch.full(
581
+ (1, 1, self.flow_lm.ldim),
582
+ fill_value=float("NaN"),
583
+ device=next(iter(self.flow_lm.parameters())).device,
584
+ dtype=self.flow_lm.dtype,
585
+ )
586
+ steps_times = []
587
+ eos_step = None
588
+ for generation_step in range(max_gen_len):
589
+ with display_execution_time("Generating latent", print_output=False) as timer:
590
+ next_latent, is_eos = self._run_flow_lm_and_increment_step(
591
+ model_state=model_state, backbone_input_latents=backbone_input
592
+ )
593
+ if is_eos.item() and eos_step is None:
594
+ eos_step = generation_step
595
+ if eos_step is not None and generation_step >= eos_step + frames_after_eos:
596
+ break
597
+
598
+ # Add generated latent to queue for immediate decoding
599
+ latents_queue.put(next_latent)
600
+ backbone_input = next_latent
601
+ steps_times.append(timer.elapsed_time_ms)
602
+ else:
603
+ if os.environ.get("KPOCKET_TTS_ERROR_WITHOUT_EOS", "0") == "1":
604
+ raise RuntimeError("Generation reached maximum length without EOS!")
605
+ logger.warning(
606
+ "Maximum generation length reached without EOS, this very often indicates an error."
607
+ )
608
+
609
+ # Add sentinel value to signal end of generation
610
+ latents_queue.put(None)
611
+ logger.info("Average generation step time: %d ms", int(statistics.mean(steps_times)))
612
+
613
+ @lru_cache(maxsize=2)
614
+ def _cached_get_state_for_audio_prompt(
615
+ self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
616
+ ) -> dict:
617
+ return self.get_state_for_audio_prompt(audio_conditioning, truncate)
618
+
619
+ @torch.no_grad
620
+ def get_state_for_audio_prompt(
621
+ self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
622
+ ) -> dict:
623
+ """Create model state conditioned on audio prompt for continuation.
624
+
625
+ This method processes an audio prompt and creates a model state that
626
+ captures the acoustic characteristics (speaker voice, style, prosody)
627
+ for use in subsequent text-to-speech generation. The resulting state
628
+ enables voice cloning and audio continuation with speaker consistency.
629
+
630
+ Args:
631
+ audio_conditioning: Audio prompt to condition on. Can be:
632
+ - Path: Local file path to audio file
633
+ - str: URL to download audio file from
634
+ - torch.Tensor: Pre-loaded audio tensor with shape [channels, samples]
635
+ truncate: Whether to truncate long audio prompts to 30 seconds.
636
+ Helps prevent memory issues with very long inputs. Defaults to False.
637
+
638
+ Returns:
639
+ dict: Model state dictionary containing hidden states and positional
640
+ information conditioned on the audio prompt. This state can be
641
+ passed to `generate_audio()` or `generate_audio_stream()` for
642
+ voice-consistent generation.
643
+
644
+ Raises:
645
+ FileNotFoundError: If audio file path doesn't exist.
646
+ ValueError: If audio tensor is invalid or empty.
647
+ RuntimeError: If audio processing or encoding fails.
648
+
649
+ Note:
650
+ - Audio is automatically resampled to the model's sample rate (24kHz)
651
+ - The audio is encoded using the Mimi compression model and projected
652
+ to the flow model's latent space
653
+ - Processing time is logged for performance monitoring
654
+ - The state preserves speaker characteristics for voice cloning
655
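+
+ Example:
+ Illustrative only; "reference.wav" is a placeholder path to a short
+ clip of the target speaker:
+
+ >>> model = TTSModel.load_model()
+ >>> state = model.get_state_for_audio_prompt(Path("reference.wav"), truncate=True)
+ >>> audio = model.generate_audio(state, "Speaking with the cloned voice.")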
+ """
656
+ if isinstance(audio_conditioning, str) and audio_conditioning in PREDEFINED_VOICES:
657
+ # We get the audio conditioning directly from the safetensors file.
658
+ prompt = load_predefined_voice(audio_conditioning)
659
+ else:
660
+ if not self.has_voice_cloning and isinstance(audio_conditioning, (str, Path)):
661
+ raise ValueError(
662
+ f"We could not download the weights for the model with voice cloning, "
663
+ f"but you're trying to use voice cloning. "
664
+ f"Without voice cloning, you can use our catalog of voices {list(PREDEFINED_VOICES)}. "
665
+ f"If you want access to the model with voice cloning, go to "
666
+ f"https://huggingface.co/kyutai/pocket-tts and accept the terms, "
667
+ f"then make sure you're logged in locally with `uvx hf auth login`."
668
+ )
669
+ if isinstance(audio_conditioning, str):
670
+ audio_conditioning = download_if_necessary(audio_conditioning)
671
+
672
+ if isinstance(audio_conditioning, Path):
673
+ audio, conditioning_sample_rate = audio_read(audio_conditioning)
674
+
675
+ if truncate:
676
+ max_samples = int(30 * conditioning_sample_rate) # 30 seconds of audio
677
+ if audio.shape[-1] > max_samples:
678
+ audio = audio[..., :max_samples]
679
+ logger.info(f"Audio truncated to first 30 seconds ({max_samples} samples)")
680
+
681
+ audio_conditioning = convert_audio(
682
+ audio, conditioning_sample_rate, self.config.mimi.sample_rate, 1
683
+ )
684
+
685
+ with display_execution_time("Encoding audio prompt"):
686
+ prompt = self._encode_audio(audio_conditioning.unsqueeze(0).to(self.device))
687
+ # import safetensors.torch
688
+ # safetensors.torch.save_file(
689
+ # {"audio_prompt": prompt},
690
+ # "/projects/huggingface/pocket-tts/embeddings/cosette.safetensors"
691
+ # )
692
+
693
+ model_state = init_states(self.flow_lm, batch_size=1, sequence_length=1000)
694
+
695
+ with display_execution_time("Prompting audio"):
696
+ self._run_flow_lm_and_increment_step(model_state=model_state, audio_conditioning=prompt)
697
+
698
+ # Optimize memory by slicing KV cache to only keep frames from the audio prompt
699
+ num_audio_frames = prompt.shape[1]
700
+ self._slice_kv_cache(model_state, num_audio_frames)
701
+
702
+ return model_state
703
+
704
+
705
+ def prepare_text_prompt(text: str) -> tuple[str, int]:
706
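+ """Normalize a text prompt and guess how many frames to generate after EOS.
+
+ The text is stripped, capitalized, given trailing punctuation when it ends
+ with a letter or digit, and padded with leading spaces when it is very short.
+ Returns the prepared text and a heuristic frame count (3 for prompts of at
+ most 4 words, otherwise 1). For example, "hello there" becomes
+ "        Hello there." with a guess of 3.
+ """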
+ text = text.strip()
707
+ if text == "":
708
+ raise ValueError("Text prompt cannot be empty")
709
+ text = text.replace("\n", " ").replace("\r", " ").replace("  ", " ")
710
+ number_of_words = len(text.split())
711
+ if number_of_words <= 4:
712
+ frames_after_eos_guess = 3
713
+ else:
714
+ frames_after_eos_guess = 1
715
+
716
+ # Make sure it starts with an uppercase letter
717
+ if not text[0].isupper():
718
+ text = text[0].upper() + text[1:]
719
+
720
+ # Let's make sure it ends with some kind of punctuation
721
+ # If it ends with a letter or digit, we add a period.
722
+ if text[-1].isalnum():
723
+ text = text + "."
724
+
725
+ # The model does not perform well when there are very few tokens, so
726
+ # we can add empty spaces at the beginning to increase the token count.
727
+ if len(text.split()) < 5:
728
+ text = " " * 8 + text
729
+
730
+ return text, frames_after_eos_guess
731
+
732
+
733
+ def split_into_best_sentences(tokenizer, text_to_generate: str) -> list[str]:
734
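+ """Split text into chunks of at most ~50 tokens, cut at sentence boundaries.
+
+ The text is first normalized with prepare_text_prompt and tokenized. Sentence
+ boundaries are detected from the punctuation tokens of ".!...?", and whole
+ sentences are then packed greedily into chunks until adding the next sentence
+ would exceed the 50-token budget.
+ """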
+ text_to_generate, _ = prepare_text_prompt(text_to_generate)
735
+ text_to_generate = text_to_generate.strip()
736
+ tokens = tokenizer(text_to_generate)
737
+ list_of_tokens = tokens.tokens[0].tolist()
738
+
739
+ _, *end_of_sentence_tokens = tokenizer(".!...?").tokens[0].tolist()
740
+
741
+ end_of_sentences_indices = [0]
742
+ previous_was_end_of_sentence_token = False
743
+
744
+ for token_idx, token in enumerate(list_of_tokens):
745
+ if token in end_of_sentence_tokens:
746
+ previous_was_end_of_sentence_token = True
747
+ else:
748
+ if previous_was_end_of_sentence_token:
749
+ end_of_sentences_indices.append(token_idx)
750
+ previous_was_end_of_sentence_token = False
751
+ end_of_sentences_indices.append(len(list_of_tokens))
752
+
753
+ nb_tokens_and_sentences = []
754
+ for i in range(len(end_of_sentences_indices) - 1):
755
+ # Decode this sentence's tokens back to text and record its token count.
756
+ start = end_of_sentences_indices[i]
757
+ end = end_of_sentences_indices[i + 1]
758
+ text = tokenizer.sp.decode(list_of_tokens[start:end])
759
+ nb_tokens_and_sentences.append((end - start, text))
760
+
761
+ max_nb_tokens_in_a_chunk = 50
762
+ chunks = []
763
+ current_chunk = ""
764
+ current_nb_of_tokens_in_chunk = 0
765
+ for nb_tokens, sentence in nb_tokens_and_sentences:
766
+ if current_chunk == "":
767
+ current_chunk = sentence
768
+ current_nb_of_tokens_in_chunk = nb_tokens
769
+ continue
770
+
771
+ if current_nb_of_tokens_in_chunk + nb_tokens > max_nb_tokens_in_a_chunk:
772
+ chunks.append(current_chunk.strip())
773
+ current_chunk = sentence
774
+ current_nb_of_tokens_in_chunk = nb_tokens
775
+ else:
776
+ current_chunk += " " + sentence
777
+ current_nb_of_tokens_in_chunk += nb_tokens
778
+
779
+ if current_chunk != "":
780
+ chunks.append(current_chunk.strip())
781
+
782
+ return chunks