pocket_tts-1.0.2-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
pocket_tts/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from beartype import BeartypeConf
+ from beartype.claw import beartype_this_package
+
+ beartype_this_package(conf=BeartypeConf(is_color=False))
+
+ from pocket_tts.models.tts_model import TTSModel # noqa: E402
+
+ # Public methods:
+ # TTSModel.device
+ # TTSModel.sample_rate
+ # TTSModel.load_model
+ # TTSModel.generate_audio
+ # TTSModel.generate_audio_stream
+ # TTSModel.get_state_for_audio_prompt
+
+ __all__ = ["TTSModel"]
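A minimal usage sketch of this public surface, modeled on how pocket_tts/main.py drives the model; the "b6369a24" variant and the "alba" voice are the package defaults from pocket_tts/default_parameters.py, and the output filename is illustrative:

    from pocket_tts import TTSModel
    from pocket_tts.data.audio import stream_audio_chunks

    # Load the default variant and build a state for a predefined voice prompt.
    tts_model = TTSModel.load_model("b6369a24")
    state = tts_model.get_state_for_audio_prompt("alba")

    # Stream generated chunks straight into a WAV file.
    chunks = tts_model.generate_audio_stream(model_state=state, text_to_generate="Hello from Pocket TTS.")
    stream_audio_chunks("hello.wav", chunks, tts_model.config.mimi.sample_rate)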
pocket_tts/__main__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/env python3
+
+ from pocket_tts.main import cli_app
+
+ if __name__ == "__main__":
+     cli_app()
pocket_tts/conditioners/base.py ADDED
@@ -0,0 +1,38 @@
+ import logging
+ from typing import Generic, NamedTuple, TypeVar
+
+ import torch
+ from torch import nn
+
+ logger = logging.getLogger(__name__)
+
+
+ Prepared = TypeVar("Prepared") # represents the prepared condition input type.
+
+
+ class TokenizedText(NamedTuple):
+     tokens: torch.Tensor # should be long tensor.
+
+
+ class BaseConditioner(nn.Module, Generic[Prepared]):
+     """Base model for all conditioner modules.
+
+     Args:
+         dim (int): internal dim of the model.
+         output_dim (int): Output dim of the conditioner.
+         force_linear (bool, optional): Force linear projection even when `dim == output_dim`.
+         output_bias (bool): if True, the output projection will have a bias.
+         learn_padding (bool): if True, the padding value will be learnt, zero otherwise.
+     """
+
+     def __init__(
+         self, dim: int, output_dim: int, output_bias: bool = False, force_linear: bool = True
+     ):
+         super().__init__()
+         self.dim = dim
+         self.output_dim = output_dim
+         assert force_linear or dim != output_dim
+         assert not output_bias
+
+     def forward(self, inputs: TokenizedText) -> torch.Tensor:
+         return self._get_condition(inputs)
@@ -0,0 +1,61 @@
+ import logging
+
+ import sentencepiece
+ import torch
+ from torch import nn
+
+ from pocket_tts.conditioners.base import BaseConditioner, TokenizedText
+ from pocket_tts.utils.utils import download_if_necessary
+
+ logger = logging.getLogger(__name__)
+
+
+ class SentencePieceTokenizer:
+     """This tokenizer should be used for natural language descriptions.
+     For example:
+     ["he didn't, know he's going home.", 'shorter sentence'] =>
+     [[78, 62, 31, 4, 78, 25, 19, 34],
+     [59, 77, PAD, PAD, PAD, PAD, PAD, PAD]]
+
+     Args:
+         nbins (int): should be equal to the vocabulary size of the sentencepiece tokenizer.
+         tokenizer_path (str): path to the sentencepiece tokenizer model.
+
+     """
+
+     def __init__(self, nbins: int, tokenizer_path: str) -> None:
+         logger.info("Loading sentencepiece tokenizer from %s", tokenizer_path)
+         tokenizer_path = download_if_necessary(tokenizer_path)
+         self.sp = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
+         assert nbins == self.sp.vocab_size(), (
+             f"sentencepiece tokenizer has vocab size={self.sp.vocab_size()} but nbins={nbins} was specified"
+         )
+
+     def __call__(self, text: str) -> TokenizedText:
+         return TokenizedText(torch.tensor(self.sp.encode(text, out_type=int))[None, :])
+
+
+ class LUTConditioner(BaseConditioner):
+     """Lookup table TextConditioner.
+
+     Args:
+         n_bins (int): Number of bins.
+         dim (int): Hidden dim of the model (text-encoder/LUT).
+         output_dim (int): Output dim of the conditioner.
+         tokenizer_path (str): Path to the sentencepiece tokenizer model.
+         possible_values (list[str] or None): list of possible values for the tokenizer.
+     """
+
+     def __init__(self, n_bins: int, tokenizer_path: str, dim: int, output_dim: int):
+         super().__init__(dim=dim, output_dim=output_dim)
+         self.tokenizer = SentencePieceTokenizer(n_bins, tokenizer_path)
+         self.embed = nn.Embedding(n_bins + 1, self.dim) # n_bins + 1 for padding.
+
+     def prepare(self, x: str) -> TokenizedText:
+         tokens = self.tokenizer(x)
+         tokens = tokens[0].to(self.embed.weight.device)
+         return TokenizedText(tokens)
+
+     def _get_condition(self, inputs: TokenizedText) -> torch.Tensor:
+         embeds = self.embed(inputs[0])
+         return embeds
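For illustration, a self-contained sketch of the lookup-table conditioning path (token ids go through an nn.Embedding, as in LUTConditioner._get_condition); the token ids below are made up, and n_bins=4000 / dim=1024 mirror the lookup_table section of the config that follows:

    import torch
    from torch import nn

    embed = nn.Embedding(4000 + 1, 1024)       # n_bins + 1 rows: the extra row is the padding index
    tokens = torch.tensor([[78, 62, 31, 4]])   # stand-in for SentencePieceTokenizer output (1, num_tokens)
    condition = embed(tokens)                  # shape (1, 4, 1024): one embedding vector per token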
@@ -0,0 +1,57 @@
+ # sig: b6369a24
+
+ weights_path: hf://kyutai/pocket-tts/tts_b6369a24.safetensors@427e3d61b276ed69fdd03de0d185fa8a8d97fc5b
+ weights_path_without_voice_cloning: hf://kyutai/pocket-tts-without-voice-cloning/tts_b6369a24.safetensors@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3
+
+ flow_lm:
+   dtype: float32
+   flow:
+     depth: 6
+     dim: 512
+   transformer:
+     d_model: 1024
+     hidden_scale: 4
+     max_period: 10000
+     num_heads: 16
+     num_layers: 6
+   lookup_table:
+     dim: 1024
+     n_bins: 4000
+     tokenizer: sentencepiece
+     tokenizer_path: hf://kyutai/pocket-tts-without-voice-cloning/tokenizer.model@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3
+   #weights_path: flow_lm_b6369a24.safetensors
+
+ mimi:
+   dtype: float32
+   sample_rate: 24000
+   channels: 1
+   frame_rate: 12.5
+   seanet:
+     dimension: 512
+     channels: 1
+     n_filters: 64
+     n_residual_layers: 1
+     ratios:
+       - 6
+       - 5
+       - 4
+     kernel_size: 7
+     residual_kernel_size: 3
+     last_kernel_size: 3
+     dilation_base: 2
+     pad_mode: constant
+     compress: 2
+   transformer:
+     d_model: 512
+     num_heads: 8
+     num_layers: 2
+     layer_scale: 0.01
+     context: 250
+     dim_feedforward: 2048
+     input_dimension: 512
+     output_dimensions:
+       - 512
+   quantizer:
+     dimension: 32
+     output_dimension: 512
+   #weights_path: mimi_b6369a24.safetensors
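A sketch of reading this configuration; the local filename and the use of PyYAML are assumptions (the diff does not show how the package loads it), while the key layout and values come from the file above:

    import yaml  # assumes PyYAML is installed

    with open("b6369a24.yaml") as f:  # hypothetical local copy of the config above
        cfg = yaml.safe_load(f)

    print(cfg["mimi"]["sample_rate"])                # 24000
    print(cfg["flow_lm"]["lookup_table"]["n_bins"])  # 4000

    # weights_path entries follow the pattern hf://<repo>/<file>@<revision>.
    repo_and_file, revision = cfg["weights_path"].removeprefix("hf://").split("@")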
@@ -0,0 +1,2 @@
+ """Audio loading and writing support. Datasets for raw audio,
+ optionally with some metadata."""
@@ -0,0 +1,144 @@
+ """
+ Audio IO methods (read, streaming write) are defined in this module.
+ WAV files use the standard-library wave module; other formats fall back to soundfile.
+ """
+
+ import logging
+ import os
+ import sys
+ import wave
+ from contextlib import nullcontext
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from beartype.typing import Iterator
+
+ logger = logging.getLogger(__name__)
+
+ FIRST_CHUNK_LENGTH_SECONDS = float(os.environ.get("FIRST_CHUNK_LENGTH_SECONDS", "0"))
+
+
+ def audio_read(filepath: str | Path) -> tuple[torch.Tensor, int]:
+     """Read audio file. WAV uses built-in wave module; other formats require soundfile."""
+     filepath = Path(filepath)
+
+     if filepath.suffix.lower() == ".wav":
+         # Use built-in wave module for WAV files
+         with wave.open(str(filepath), "rb") as wav_file:
+             sample_rate = wav_file.getframerate()
+             n_channels = wav_file.getnchannels()
+             raw_data = wav_file.readframes(-1)
+             samples = np.frombuffer(raw_data, dtype=np.int16).astype(np.float32) / 32768.0
+             if n_channels > 1:
+                 samples = samples.reshape(-1, n_channels).mean(axis=1)
+             return torch.from_numpy(samples).unsqueeze(0), sample_rate
+
+     # For non-WAV formats, use soundfile (optional dependency)
+     try:
+         import soundfile as sf
+     except ImportError as e:
+         raise ImportError(
+             "soundfile is required to read non-WAV audio files. "
+             "Install with: `pip install soundfile` or `uvx --with soundfile`"
+         ) from e
+
+     data, sample_rate = sf.read(str(filepath), dtype="float32")
+     if data.ndim == 1:
+         wav = torch.from_numpy(data).unsqueeze(0)
+     else:
+         wav = torch.from_numpy(data.mean(axis=1)).unsqueeze(0)
+     return wav, sample_rate
+
+
+ class StreamingWAVWriter:
+     """WAV writer using Python's standard library wave module."""
+
+     def __init__(self, output_stream, sample_rate: int):
+         self.output_stream = output_stream
+         self.sample_rate = sample_rate
+         self.wave_writer = None
+         self.first_chunk_buffer = []
+
+     def write_header(self, sample_rate: int):
+         """Initialize WAV writer with header."""
+         # For stdout streaming, we need to handle the unseekable stream case
+         # The wave module supports unseekable streams since Python 3.4
+         self.wave_writer = wave.open(self.output_stream, "wb")
+         self.wave_writer.setnchannels(1) # Mono
+         self.wave_writer.setsampwidth(2) # 16-bit
+         self.wave_writer.setframerate(sample_rate)
+         self.wave_writer.setnframes(1_000_000_000)
+
+     def write_pcm_data(self, audio_chunk: torch.Tensor):
+         """Write PCM data using wave module."""
+         # Convert to int16 PCM bytes
+         chunk_int16 = (audio_chunk.clamp(-1, 1) * 32767).short()
+         chunk_bytes = chunk_int16.detach().cpu().numpy().tobytes()
+
+         if self.first_chunk_buffer is not None:
+             self.first_chunk_buffer.append(chunk_bytes)
+             total_length = sum(len(c) for c in self.first_chunk_buffer)
+             target_length = (
+                 int(self.sample_rate * FIRST_CHUNK_LENGTH_SECONDS) * 2
+             ) # 2 bytes per sample
+             if total_length < target_length:
+                 return
+             self._flush()
+             return
+
+         # Use writeframesraw to avoid frame count validation for streaming
+         self.wave_writer.writeframesraw(chunk_bytes)
+
+     def _flush(self):
+         if self.first_chunk_buffer is not None:
+             self.wave_writer.writeframesraw(b"".join(self.first_chunk_buffer))
+             self.first_chunk_buffer = None
+
+     def finalize(self):
+         """Close the wave writer."""
+         self._flush()
+
+         # Let's add 200ms of silence to ensure proper playback
+         silence_duration_sec = 0.2
+         num_silence_samples = int(self.sample_rate * silence_duration_sec)
+
+         self.wave_writer.writeframesraw(bytes(num_silence_samples * 2))
+
+         if self.wave_writer:
+             # do not update the header for unseekable streams
+             self.wave_writer._patchheader = lambda: None
+             self.wave_writer.close()
+
+
+ def is_file_like(obj):
+     """Check if object has basic file-like methods."""
+     return all(hasattr(obj, attr) for attr in ["write", "close"])
+
+
+ def stream_audio_chunks(
+     path: str | Path | None | Any, audio_chunks: Iterator[torch.Tensor], sample_rate: int
+ ):
+     """Stream audio chunks to a WAV file or stdout, optionally playing them."""
+     if path == "-":
+         f = sys.stdout.buffer
+     elif path is None:
+         f = nullcontext()
+     elif is_file_like(path):
+         f = path
+     else:
+         f = open(path, "wb")
+
+     with f:
+         if path is not None:
+             writer = StreamingWAVWriter(f, sample_rate)
+             writer.write_header(sample_rate)
+
+         for chunk in audio_chunks:
+             # Then write to file
+             if path is not None:
+                 writer.write_pcm_data(chunk)
+
+         if path is not None:
+             writer.finalize()
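To show the streaming writer end to end, a small sketch that feeds synthetic chunks through stream_audio_chunks; the 440 Hz tone generator is illustrative and not part of the package, and passing "-" instead of a filename would stream to stdout as in the CLI:

    import math
    import torch

    from pocket_tts.data.audio import stream_audio_chunks

    def tone_chunks(sample_rate: int, seconds: float = 1.0, chunk: int = 1920):
        # Yield a 440 Hz sine wave in [-1, 1], split into fixed-size chunks.
        t = torch.arange(int(sample_rate * seconds)) / sample_rate
        wav = torch.sin(2 * math.pi * 440.0 * t)
        for start in range(0, wav.numel(), chunk):
            yield wav[start:start + chunk]

    stream_audio_chunks("tone.wav", tone_chunks(24000), sample_rate=24000)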
@@ -0,0 +1,28 @@
+ """Various utilities for audio conversion (PCM format, sample rate and channels),
+ and volume normalization."""
+
+ import torch
+ from scipy.signal import resample_poly
+
+
+ def convert_audio(
+     wav: torch.Tensor, from_rate: int | float, to_rate: int | float, to_channels: int
+ ) -> torch.Tensor:
+     """Convert audio to new sample rate and number of audio channels."""
+     if from_rate != to_rate:
+         # Convert to numpy for scipy resampling
+         wav_np = wav.detach().cpu().numpy()
+
+         # Calculate resampling parameters
+         gcd = int(torch.gcd(torch.tensor(from_rate), torch.tensor(to_rate)).item())
+         up = int(to_rate // gcd)
+         down = int(from_rate // gcd)
+
+         # Resample using scipy
+         resampled_np = resample_poly(wav_np, up, down, axis=-1)
+
+         # Convert back to torch tensor
+         wav = torch.from_numpy(resampled_np).to(wav.device).to(wav.dtype)
+
+     assert wav.shape[-2] == to_channels
+     return wav
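A quick sketch of convert_audio resampling a mono signal from 44.1 kHz to 24 kHz; the input tensor is synthetic, and the import path is a guess since this file's name is not shown in the diff:

    import torch

    from pocket_tts.data.audio_utils import convert_audio  # assumed module path

    wav = torch.randn(1, 44100)  # (channels, samples): one second of mono audio at 44.1 kHz
    out = convert_audio(wav, from_rate=44100, to_rate=24000, to_channels=1)
    print(out.shape)  # torch.Size([1, 24000])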
@@ -0,0 +1,7 @@
+ DEFAULT_AUDIO_PROMPT = "alba"
+ DEFAULT_VARIANT = "b6369a24"
+ DEFAULT_TEMPERATURE = 0.7
+ DEFAULT_LSD_DECODE_STEPS = 1
+ DEFAULT_NOISE_CLAMP = None
+ DEFAULT_EOS_THRESHOLD = -4.0
+ DEFAULT_FRAMES_AFTER_EOS = None
pocket_tts/main.py ADDED
@@ -0,0 +1,262 @@
+ import io
+ import logging
+ import os
+ import tempfile
+ import threading
+ from pathlib import Path
+ from queue import Queue
+
+ import typer
+ import uvicorn
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import FileResponse, StreamingResponse
+ from typing_extensions import Annotated
+
+ from pocket_tts.data.audio import stream_audio_chunks
+ from pocket_tts.default_parameters import (
+     DEFAULT_AUDIO_PROMPT,
+     DEFAULT_EOS_THRESHOLD,
+     DEFAULT_FRAMES_AFTER_EOS,
+     DEFAULT_LSD_DECODE_STEPS,
+     DEFAULT_NOISE_CLAMP,
+     DEFAULT_TEMPERATURE,
+     DEFAULT_VARIANT,
+ )
+ from pocket_tts.models.tts_model import TTSModel
+ from pocket_tts.utils.logging_utils import enable_logging
+ from pocket_tts.utils.utils import PREDEFINED_VOICES, size_of_dict
+
+ logger = logging.getLogger(__name__)
+
+ cli_app = typer.Typer(
+     help="Kyutai Pocket TTS - Text-to-Speech generation tool", pretty_exceptions_show_locals=False
+ )
+
+
+ # ------------------------------------------------------
+ # The pocket-tts server implementation
+ # ------------------------------------------------------
+
+ # Global model instance
+ tts_model = None
+ global_model_state = None
+
+ web_app = FastAPI(
+     title="Kyutai Pocket TTS API", description="Text-to-Speech generation API", version="1.0.0"
+ )
+ web_app.add_middleware(
+     CORSMiddleware,
+     allow_origins=[
+         "http://localhost:3000",
+         "https://pod1-10007.internal.kyutai.org",
+         "https://kyutai.org",
+     ],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @web_app.get("/")
+ async def root():
+     """Serve the frontend."""
+     static_path = Path(__file__).parent / "static" / "index.html"
+     return FileResponse(static_path)
+
+
+ @web_app.get("/health")
+ async def health():
+     return {"status": "healthy"}
+
+
+ def write_to_queue(queue, text_to_generate, model_state):
+     """Allows writing to the StreamingResponse as if it were a file."""
+
+     class FileLikeToQueue(io.IOBase):
+         def __init__(self, queue):
+             self.queue = queue
+
+         def write(self, data):
+             self.queue.put(data)
+
+         def flush(self):
+             pass
+
+         def close(self):
+             self.queue.put(None)
+
+     audio_chunks = tts_model.generate_audio_stream(
+         model_state=model_state, text_to_generate=text_to_generate
+     )
+     stream_audio_chunks(FileLikeToQueue(queue), audio_chunks, tts_model.config.mimi.sample_rate)
+
+
+ def generate_data_with_state(text_to_generate: str, model_state: dict):
+     queue = Queue()
+
+     # Run the generation in a background thread
+     thread = threading.Thread(target=write_to_queue, args=(queue, text_to_generate, model_state))
+     thread.start()
+
+     # Yield data as it becomes available
+     i = 0
+     while True:
+         data = queue.get()
+         if data is None:
+             break
+         i += 1
+         yield data
+
+     thread.join()
+
+
+ @web_app.post("/tts")
+ def text_to_speech(
+     text: str = Form(...),
+     voice_url: str | None = Form(None),
+     voice_wav: UploadFile | None = File(None),
+ ):
+     """
+     Generate speech from text using the pre-loaded voice prompt or a custom voice.
+
+     Args:
+         text: Text to convert to speech
+         voice_url: Optional voice URL (http://, https://, or hf://)
+         voice_wav: Optional uploaded voice file (mutually exclusive with voice_url)
+     """
+     if not text.strip():
+         raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+     if voice_url is not None and voice_wav is not None:
+         raise HTTPException(status_code=400, detail="Cannot provide both voice_url and voice_wav")
+
+     # Use the appropriate model state
+     if voice_url is not None:
+         if not (
+             voice_url.startswith("http://")
+             or voice_url.startswith("https://")
+             or voice_url.startswith("hf://")
+             or voice_url in PREDEFINED_VOICES
+         ):
+             raise HTTPException(
+                 status_code=400, detail="voice_url must start with http://, https://, or hf://"
+             )
+         model_state = tts_model._cached_get_state_for_audio_prompt(voice_url)
+         logging.warning("Using voice from URL: %s", voice_url)
+     elif voice_wav is not None:
+         # Use uploaded voice file - preserve extension for format detection
+         suffix = Path(voice_wav.filename).suffix if voice_wav.filename else ".wav"
+         with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+             content = voice_wav.file.read()
+             temp_file.write(content)
+             temp_file.flush()
+
+         try:
+             model_state = tts_model.get_state_for_audio_prompt(
+                 Path(temp_file.name), truncate=True
+             )
+         finally:
+             os.unlink(temp_file.name)
+     else:
+         # Use default global model state
+         model_state = global_model_state
+
+     return StreamingResponse(
+         generate_data_with_state(text, model_state),
+         media_type="audio/wav",
+         headers={
+             "Content-Disposition": "attachment; filename=generated_speech.wav",
+             "Transfer-Encoding": "chunked",
+         },
+     )
+
+
+ @cli_app.command()
+ def serve(
+     voice: Annotated[
+         str, typer.Option(help="Path to voice prompt audio file (voice to clone)")
+     ] = DEFAULT_AUDIO_PROMPT,
+     host: Annotated[str, typer.Option(help="Host to bind to")] = "localhost",
+     port: Annotated[int, typer.Option(help="Port to bind to")] = 8000,
+     reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
+ ):
+     """Start the FastAPI server."""
+
+     global tts_model, global_model_state
+     tts_model = TTSModel.load_model(DEFAULT_VARIANT)
+
+     # Pre-load the voice prompt
+     global_model_state = tts_model.get_state_for_audio_prompt(voice)
+     logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
+
+     uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
+
+
+ # ------------------------------------------------------
+ # The pocket-tts single generation CLI implementation
+ # ------------------------------------------------------
+
+
+ @cli_app.command()
+ def generate(
+     text: Annotated[
+         str, typer.Option(help="Text to generate")
+     ] = "Hello world. I am Kyutai's Pocket TTS. I'm fast enough to run on small CPUs. I hope you'll like me.",
+     voice: Annotated[
+         str, typer.Option(help="Path to audio conditioning file (voice to clone)")
+     ] = DEFAULT_AUDIO_PROMPT,
+     quiet: Annotated[bool, typer.Option("-q", "--quiet", help="Disable logging output")] = False,
+     variant: Annotated[str, typer.Option(help="Model signature")] = DEFAULT_VARIANT,
+     lsd_decode_steps: Annotated[
+         int, typer.Option(help="Number of generation steps")
+     ] = DEFAULT_LSD_DECODE_STEPS,
+     temperature: Annotated[
+         float, typer.Option(help="Temperature for generation")
+     ] = DEFAULT_TEMPERATURE,
+     noise_clamp: Annotated[float, typer.Option(help="Noise clamp value")] = DEFAULT_NOISE_CLAMP,
+     eos_threshold: Annotated[float, typer.Option(help="EOS threshold")] = DEFAULT_EOS_THRESHOLD,
+     frames_after_eos: Annotated[
+         int, typer.Option(help="Number of frames to generate after EOS")
+     ] = DEFAULT_FRAMES_AFTER_EOS,
+     output_path: Annotated[
+         str, typer.Option(help="Output path for generated audio")
+     ] = "./tts_output.wav",
+     device: Annotated[str, typer.Option(help="Device to use")] = "cpu",
+ ):
+     """Generate speech using Kyutai Pocket TTS."""
+     if "cuda" in device:
+         # CUDA graph capture does not play nicely with multithreading.
+         os.environ["NO_CUDA_GRAPH"] = "1"
+
+     log_level = logging.ERROR if quiet else logging.INFO
+     with enable_logging("pocket_tts", log_level):
+         tts_model = TTSModel.load_model(
+             variant, temperature, lsd_decode_steps, noise_clamp, eos_threshold
+         )
+         tts_model.to(device)
+
+         model_state_for_voice = tts_model.get_state_for_audio_prompt(voice)
+         # Stream audio generation directly to file or stdout
+         audio_chunks = tts_model.generate_audio_stream(
+             model_state=model_state_for_voice,
+             text_to_generate=text,
+             frames_after_eos=frames_after_eos,
+         )
+
+         stream_audio_chunks(output_path, audio_chunks, tts_model.config.mimi.sample_rate)
+
+         # Only print the result message if not writing to stdout
+         if output_path != "-":
+             logger.info("Results written to %s", output_path)
+             logger.info("-" * 20)
+             logger.info(
+                 "If you want to try multiple voices and prompts quickly, try the `serve` command."
+             )
+             logger.info(
+                 "If you like Kyutai projects, comment, like, subscribe at https://x.com/kyutai_labs"
+             )
+
+
+ if __name__ == "__main__":
+     cli_app()
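For completeness, a sketch of calling the streaming /tts endpoint started by the serve command; the requests dependency and the localhost URL are assumptions (serve binds to localhost:8000 by default), and omitting voice_url/voice_wav uses the voice pre-loaded at startup:

    import requests  # not a package dependency; used here only for illustration

    resp = requests.post(
        "http://localhost:8000/tts",
        data={"text": "Hello from the Pocket TTS server."},
        stream=True,
    )
    resp.raise_for_status()

    # The response body is a streamed WAV file (Transfer-Encoding: chunked).
    with open("server_speech.wav", "wb") as f:
        for part in resp.iter_content(chunk_size=8192):
            f.write(part)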
@@ -0,0 +1,3 @@
+ """
+ Models for EnCodec (and Mimi), and FlowLMModel.
+ """