pocket-tts 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0
pocket_tts/__init__.py
ADDED
@@ -0,0 +1,16 @@
+from beartype import BeartypeConf
+from beartype.claw import beartype_this_package
+
+beartype_this_package(conf=BeartypeConf(is_color=False))
+
+from pocket_tts.models.tts_model import TTSModel  # noqa: E402
+
+# Public methods:
+# TTSModel.device
+# TTSModel.sample_rate
+# TTSModel.load_model
+# TTSModel.generate_audio
+# TTSModel.generate_audio_stream
+# TTSModel.get_state_for_audio_prompt
+
+__all__ = ["TTSModel"]
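The comments above enumerate the public surface. Below is a minimal end-to-end sketch of that API, assuming the call signatures used later in `pocket_tts/main.py` (`load_model`, `get_state_for_audio_prompt`, `generate_audio_stream`, `config.mimi.sample_rate`); the voice file path is a placeholder.

```python
# Usage sketch mirroring how pocket_tts/main.py drives the model.
from pocket_tts import TTSModel
from pocket_tts.data.audio import stream_audio_chunks
from pocket_tts.default_parameters import DEFAULT_VARIANT

model = TTSModel.load_model(DEFAULT_VARIANT)
state = model.get_state_for_audio_prompt("my_voice.wav")  # placeholder path to a voice to clone
chunks = model.generate_audio_stream(model_state=state, text_to_generate="Hello from Pocket TTS.")
stream_audio_chunks("out.wav", chunks, model.config.mimi.sample_rate)
```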
pocket_tts/__main__.py
ADDED

pocket_tts/conditioners/__init__.py
File without changes

pocket_tts/conditioners/base.py
ADDED
@@ -0,0 +1,38 @@
+import logging
+from typing import Generic, NamedTuple, TypeVar
+
+import torch
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+Prepared = TypeVar("Prepared")  # represents the prepared condition input type.
+
+
+class TokenizedText(NamedTuple):
+    tokens: torch.Tensor  # should be long tensor.
+
+
+class BaseConditioner(nn.Module, Generic[Prepared]):
+    """Base model for all conditioner modules.
+
+    Args:
+        dim (int): internal dim of the model.
+        output_dim (int): Output dim of the conditioner.
+        force_linear (bool, optional): Force linear projection even when `dim == output_dim`.
+        output_bias (bool): if True, the output projection will have a bias.
+        learn_padding (bool): if True, the padding value will be learnt, zero otherwise.
+    """
+
+    def __init__(
+        self, dim: int, output_dim: int, output_bias: bool = False, force_linear: bool = True
+    ):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        assert force_linear or dim != output_dim
+        assert not output_bias
+
+    def forward(self, inputs: TokenizedText) -> torch.Tensor:
+        return self._get_condition(inputs)
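For orientation: a subclass only needs to supply `_get_condition`; the base `forward` dispatches to it. A minimal sketch with a hypothetical `ConstantConditioner` (not part of the package) that ignores its input:

```python
import torch

from pocket_tts.conditioners.base import BaseConditioner, TokenizedText

# Hypothetical subclass, for illustration only: returns a learned constant
# embedding regardless of the tokenized input.
class ConstantConditioner(BaseConditioner[TokenizedText]):
    def __init__(self, dim: int, output_dim: int):
        super().__init__(dim=dim, output_dim=output_dim)
        self.constant = torch.nn.Parameter(torch.zeros(1, 1, dim))

    def _get_condition(self, inputs: TokenizedText) -> torch.Tensor:
        batch = inputs.tokens.shape[0]
        return self.constant.expand(batch, 1, -1)  # (batch, 1, dim)
```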
pocket_tts/conditioners/text.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+
+import sentencepiece
+import torch
+from torch import nn
+
+from pocket_tts.conditioners.base import BaseConditioner, TokenizedText
+from pocket_tts.utils.utils import download_if_necessary
+
+logger = logging.getLogger(__name__)
+
+
+class SentencePieceTokenizer:
+    """This tokenizer should be used for natural language descriptions.
+    For example:
+    ["he didn't, know he's going home.", 'shorter sentence'] =>
+    [[78, 62, 31, 4, 78, 25, 19, 34],
+     [59, 77, PAD, PAD, PAD, PAD, PAD, PAD]]
+
+    Args:
+        n_bins (int): should be equal to the number of elements in the sentencepiece tokenizer.
+        tokenizer_path (str): path to the sentencepiece tokenizer model.
+
+    """
+
+    def __init__(self, nbins: int, tokenizer_path: str) -> None:
+        logger.info("Loading sentencepiece tokenizer from %s", tokenizer_path)
+        tokenizer_path = download_if_necessary(tokenizer_path)
+        self.sp = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
+        assert nbins == self.sp.vocab_size(), (
+            f"sentencepiece tokenizer has vocab size={self.sp.vocab_size()} but nbins={nbins} was specified"
+        )
+
+    def __call__(self, text: str) -> TokenizedText:
+        return TokenizedText(torch.tensor(self.sp.encode(text, out_type=int))[None, :])
+
+
+class LUTConditioner(BaseConditioner):
+    """Lookup table TextConditioner.
+
+    Args:
+        n_bins (int): Number of bins.
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+        tokenizer (str): Name of the tokenizer.
+        possible_values (list[str] or None): list of possible values for the tokenizer.
+    """
+
+    def __init__(self, n_bins: int, tokenizer_path: str, dim: int, output_dim: int):
+        super().__init__(dim=dim, output_dim=output_dim)
+        self.tokenizer = SentencePieceTokenizer(n_bins, tokenizer_path)
+        self.embed = nn.Embedding(n_bins + 1, self.dim)  # n_bins + 1 for padding.
+
+    def prepare(self, x: str) -> TokenizedText:
+        tokens = self.tokenizer(x)
+        tokens = tokens[0].to(self.embed.weight.device)
+        return TokenizedText(tokens)
+
+    def _get_condition(self, inputs: TokenizedText) -> torch.Tensor:
+        embeds = self.embed(inputs[0])
+        return embeds
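The lookup table reserves one extra row for padding: `nn.Embedding(n_bins + 1, dim)` leaves index `n_bins` free as the pad slot (an inference from the "n_bins + 1 for padding" comment). A standalone sketch of the shape arithmetic, using the values from the bundled config (n_bins=4000, dim=1024):

```python
import torch
from torch import nn

n_bins, dim = 4000, 1024               # values from config/b6369a24.yaml
embed = nn.Embedding(n_bins + 1, dim)  # row 4000 assumed reserved for padding

# Second row padded with the extra index, matching the docstring's PAD example.
tokens = torch.tensor([[78, 62, 31, 4], [59, 77, n_bins, n_bins]])
cond = embed(tokens)
print(cond.shape)  # torch.Size([2, 4, 1024]) -> (batch, sequence, dim)
```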
pocket_tts/config/b6369a24.yaml
ADDED
@@ -0,0 +1,57 @@
+# sig: b6369a24
+
+weights_path: hf://kyutai/pocket-tts/tts_b6369a24.safetensors@427e3d61b276ed69fdd03de0d185fa8a8d97fc5b
+weights_path_without_voice_cloning: hf://kyutai/pocket-tts-without-voice-cloning/tts_b6369a24.safetensors@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3
+
+flow_lm:
+  dtype: float32
+  flow:
+    depth: 6
+    dim: 512
+  transformer:
+    d_model: 1024
+    hidden_scale: 4
+    max_period: 10000
+    num_heads: 16
+    num_layers: 6
+  lookup_table:
+    dim: 1024
+    n_bins: 4000
+    tokenizer: sentencepiece
+    tokenizer_path: hf://kyutai/pocket-tts-without-voice-cloning/tokenizer.model@d4fdd22ae8c8e1cb3634e150ebeff1dab2d16df3
+  #weights_path: flow_lm_b6369a24.safetensors
+
+mimi:
+  dtype: float32
+  sample_rate: 24000
+  channels: 1
+  frame_rate: 12.5
+  seanet:
+    dimension: 512
+    channels: 1
+    n_filters: 64
+    n_residual_layers: 1
+    ratios:
+    - 6
+    - 5
+    - 4
+    kernel_size: 7
+    residual_kernel_size: 3
+    last_kernel_size: 3
+    dilation_base: 2
+    pad_mode: constant
+    compress: 2
+  transformer:
+    d_model: 512
+    num_heads: 8
+    num_layers: 2
+    layer_scale: 0.01
+    context: 250
+    dim_feedforward: 2048
+    input_dimension: 512
+    output_dimensions:
+    - 512
+  quantizer:
+    dimension: 32
+    output_dimension: 512
+  #weights_path: mimi_b6369a24.safetensors
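A quick way to sanity-check this file outside the package. The package itself routes config loading through `pocket_tts/utils/config.py`, whose API is not shown in this diff, so plain `yaml` and a relative path are assumptions here:

```python
import yaml  # assumption: pyyaml, not necessarily how the package loads it

with open("pocket_tts/config/b6369a24.yaml") as f:
    cfg = yaml.safe_load(f)

sr = cfg["mimi"]["sample_rate"]  # 24000
fr = cfg["mimi"]["frame_rate"]   # 12.5
print(sr, fr, int(sr / fr))      # 24000 12.5 1920 samples per Mimi frame
print(cfg["flow_lm"]["lookup_table"]["n_bins"])  # 4000, matches the text LUT size
```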
pocket_tts/data/audio.py
ADDED
@@ -0,0 +1,144 @@
+"""
+Audio IO methods are defined in this module (info, read, write),
+We rely on av library for faster read when possible, otherwise on torchaudio.
+"""
+
+import logging
+import os
+import sys
+import wave
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from beartype.typing import Iterator
+
+logger = logging.getLogger(__name__)
+
+FIRST_CHUNK_LENGTH_SECONDS = float(os.environ.get("FIRST_CHUNK_LENGTH_SECONDS", "0"))
+
+
+def audio_read(filepath: str | Path) -> tuple[torch.Tensor, int]:
+    """Read audio file. WAV uses built-in wave module; other formats require soundfile."""
+    filepath = Path(filepath)
+
+    if filepath.suffix.lower() == ".wav":
+        # Use built-in wave module for WAV files
+        with wave.open(str(filepath), "rb") as wav_file:
+            sample_rate = wav_file.getframerate()
+            n_channels = wav_file.getnchannels()
+            raw_data = wav_file.readframes(-1)
+            samples = np.frombuffer(raw_data, dtype=np.int16).astype(np.float32) / 32768.0
+            if n_channels > 1:
+                samples = samples.reshape(-1, n_channels).mean(axis=1)
+            return torch.from_numpy(samples).unsqueeze(0), sample_rate
+
+    # For non-WAV formats, use soundfile (optional dependency)
+    try:
+        import soundfile as sf
+    except ImportError as e:
+        raise ImportError(
+            "soundfile is required to read non-WAV audio files. "
+            "Install with: `pip install soundfile` or `uvx --with soundfile`"
+        ) from e
+
+    data, sample_rate = sf.read(str(filepath), dtype="float32")
+    if data.ndim == 1:
+        wav = torch.from_numpy(data).unsqueeze(0)
+    else:
+        wav = torch.from_numpy(data.mean(axis=1)).unsqueeze(0)
+    return wav, sample_rate
+
+
+class StreamingWAVWriter:
+    """WAV writer using Python's standard library wave module."""
+
+    def __init__(self, output_stream, sample_rate: int):
+        self.output_stream = output_stream
+        self.sample_rate = sample_rate
+        self.wave_writer = None
+        self.first_chunk_buffer = []
+
+    def write_header(self, sample_rate: int):
+        """Initialize WAV writer with header."""
+        # For stdout streaming, we need to handle the unseekable stream case
+        # The wave module supports unseekable streams since Python 3.4
+        self.wave_writer = wave.open(self.output_stream, "wb")
+        self.wave_writer.setnchannels(1)  # Mono
+        self.wave_writer.setsampwidth(2)  # 16-bit
+        self.wave_writer.setframerate(sample_rate)
+        self.wave_writer.setnframes(1_000_000_000)
+
+    def write_pcm_data(self, audio_chunk: torch.Tensor):
+        """Write PCM data using wave module."""
+        # Convert to int16 PCM bytes
+        chunk_int16 = (audio_chunk.clamp(-1, 1) * 32767).short()
+        chunk_bytes = chunk_int16.detach().cpu().numpy().tobytes()
+
+        if self.first_chunk_buffer is not None:
+            self.first_chunk_buffer.append(chunk_bytes)
+            total_length = sum(len(c) for c in self.first_chunk_buffer)
+            target_length = (
+                int(self.sample_rate * FIRST_CHUNK_LENGTH_SECONDS) * 2
+            )  # 2 bytes per sample
+            if total_length < target_length:
+                return
+            self._flush()
+            return
+
+        # Use writeframesraw to avoid frame count validation for streaming
+        self.wave_writer.writeframesraw(chunk_bytes)
+
+    def _flush(self):
+        if self.first_chunk_buffer is not None:
+            self.wave_writer.writeframesraw(b"".join(self.first_chunk_buffer))
+            self.first_chunk_buffer = None
+
+    def finalize(self):
+        """Close the wave writer."""
+        self._flush()
+
+        # Let's add 200ms of silence to ensure proper playback
+        silence_duration_sec = 0.2
+        num_silence_samples = int(self.sample_rate * silence_duration_sec)
+
+        self.wave_writer.writeframesraw(bytes(num_silence_samples * 2))
+
+        if self.wave_writer:
+            # do not update the header for unseekable streams
+            self.wave_writer._patchheader = lambda: None
+            self.wave_writer.close()
+
+
+def is_file_like(obj):
+    """Check if object has basic file-like methods."""
+    return all(hasattr(obj, attr) for attr in ["write", "close"])
+
+
+def stream_audio_chunks(
+    path: str | Path | None | Any, audio_chunks: Iterator[torch.Tensor], sample_rate: int
+):
+    """Stream audio chunks to a WAV file or stdout, optionally playing them."""
+    if path == "-":
+        f = sys.stdout.buffer
+    elif path is None:
+        f = nullcontext()
+    elif is_file_like(path):
+        f = path
+    else:
+        f = open(path, "wb")
+
+    with f:
+        if path is not None:
+            writer = StreamingWAVWriter(f, sample_rate)
+            writer.write_header(sample_rate)
+
+        for chunk in audio_chunks:
+            # Then write to file
+            if path is not None:
+                writer.write_pcm_data(chunk)
+
+        if path is not None:
+            writer.finalize()
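A self-contained sketch of the writer path above: synthesize a short sine tone, feed it through `stream_audio_chunks` in frame-sized pieces, then read it back with `audio_read`. It depends only on this module and torch:

```python
import math

import torch

from pocket_tts.data.audio import audio_read, stream_audio_chunks

sample_rate = 24000
t = torch.arange(sample_rate) / sample_rate  # one second of sample times
tone = torch.sin(2 * math.pi * 440.0 * t)    # 440 Hz sine in [-1, 1]
chunks = iter(tone.split(1920))              # 1920 samples = one 12.5 Hz frame at 24 kHz

stream_audio_chunks("tone.wav", chunks, sample_rate)  # streamed 16-bit mono WAV
wav, sr = audio_read("tone.wav")                      # round-trip via the wave path
print(wav.shape, sr)  # torch.Size([1, 28800]) 24000 -> 1 s tone + 200 ms trailing silence
```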
pocket_tts/data/audio_utils.py
ADDED
@@ -0,0 +1,28 @@
+"""Various utilities for audio conversion (pcm format, sample rate and channels),
+and volume normalization."""
+
+import torch
+from scipy.signal import resample_poly
+
+
+def convert_audio(
+    wav: torch.Tensor, from_rate: int | float, to_rate: int | float, to_channels: int
+) -> torch.Tensor:
+    """Convert audio to new sample rate and number of audio channels."""
+    if from_rate != to_rate:
+        # Convert to numpy for scipy resampling
+        wav_np = wav.detach().cpu().numpy()
+
+        # Calculate resampling parameters
+        gcd = int(torch.gcd(torch.tensor(from_rate), torch.tensor(to_rate)).item())
+        up = int(to_rate // gcd)
+        down = int(from_rate // gcd)
+
+        # Resample using scipy
+        resampled_np = resample_poly(wav_np, up, down, axis=-1)
+
+        # Convert back to torch tensor
+        wav = torch.from_numpy(resampled_np).to(wav.device).to(wav.dtype)
+
+    assert wav.shape[-2] == to_channels
+    return wav
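Note that `convert_audio` only asserts the channel count rather than downmixing, so the caller supplies audio already at the target channel layout. A usage sketch resampling one second of mono 44.1 kHz audio to the model's 24 kHz:

```python
import torch

from pocket_tts.data.audio_utils import convert_audio

wav_44k = torch.randn(1, 44100)  # (channels, samples): 1 s of mono noise
wav_24k = convert_audio(wav_44k, from_rate=44100, to_rate=24000, to_channels=1)
print(wav_24k.shape)  # torch.Size([1, 24000]); gcd(44100, 24000)=300 -> up=80, down=147
```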
pocket_tts/main.py
ADDED
@@ -0,0 +1,262 @@
+import io
+import logging
+import os
+import tempfile
+import threading
+from pathlib import Path
+from queue import Queue
+
+import typer
+import uvicorn
+from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, StreamingResponse
+from typing_extensions import Annotated
+
+from pocket_tts.data.audio import stream_audio_chunks
+from pocket_tts.default_parameters import (
+    DEFAULT_AUDIO_PROMPT,
+    DEFAULT_EOS_THRESHOLD,
+    DEFAULT_FRAMES_AFTER_EOS,
+    DEFAULT_LSD_DECODE_STEPS,
+    DEFAULT_NOISE_CLAMP,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_VARIANT,
+)
+from pocket_tts.models.tts_model import TTSModel
+from pocket_tts.utils.logging_utils import enable_logging
+from pocket_tts.utils.utils import PREDEFINED_VOICES, size_of_dict
+
+logger = logging.getLogger(__name__)
+
+cli_app = typer.Typer(
+    help="Kyutai Pocket TTS - Text-to-Speech generation tool", pretty_exceptions_show_locals=False
+)
+
+
+# ------------------------------------------------------
+# The pocket-tts server implementation
+# ------------------------------------------------------
+
+# Global model instance
+tts_model = None
+global_model_state = None
+
+web_app = FastAPI(
+    title="Kyutai Pocket TTS API", description="Text-to-Speech generation API", version="1.0.0"
+)
+web_app.add_middleware(
+    CORSMiddleware,
+    allow_origins=[
+        "http://localhost:3000",
+        "https://pod1-10007.internal.kyutai.org",
+        "https://kyutai.org",
+    ],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@web_app.get("/")
+async def root():
+    """Serve the frontend."""
+    static_path = Path(__file__).parent / "static" / "index.html"
+    return FileResponse(static_path)
+
+
+@web_app.get("/health")
+async def health():
+    return {"status": "healthy"}
+
+
+def write_to_queue(queue, text_to_generate, model_state):
+    """Allows writing to the StreamingResponse as if it were a file."""
+
+    class FileLikeToQueue(io.IOBase):
+        def __init__(self, queue):
+            self.queue = queue
+
+        def write(self, data):
+            self.queue.put(data)
+
+        def flush(self):
+            pass
+
+        def close(self):
+            self.queue.put(None)
+
+    audio_chunks = tts_model.generate_audio_stream(
+        model_state=model_state, text_to_generate=text_to_generate
+    )
+    stream_audio_chunks(FileLikeToQueue(queue), audio_chunks, tts_model.config.mimi.sample_rate)
+
+
+def generate_data_with_state(text_to_generate: str, model_state: dict):
+    queue = Queue()
+
+    # Run your function in a thread
+    thread = threading.Thread(target=write_to_queue, args=(queue, text_to_generate, model_state))
+    thread.start()
+
+    # Yield data as it becomes available
+    i = 0
+    while True:
+        data = queue.get()
+        if data is None:
+            break
+        i += 1
+        yield data
+
+    thread.join()
+
+
+@web_app.post("/tts")
+def text_to_speech(
+    text: str = Form(...),
+    voice_url: str | None = Form(None),
+    voice_wav: UploadFile | None = File(None),
+):
+    """
+    Generate speech from text using the pre-loaded voice prompt or a custom voice.
+
+    Args:
+        text: Text to convert to speech
+        voice_url: Optional voice URL (http://, https://, or hf://)
+        voice_wav: Optional uploaded voice file (mutually exclusive with voice_url)
+    """
+    if not text.strip():
+        raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+    if voice_url is not None and voice_wav is not None:
+        raise HTTPException(status_code=400, detail="Cannot provide both voice_url and voice_wav")
+
+    # Use the appropriate model state
+    if voice_url is not None:
+        if not (
+            voice_url.startswith("http://")
+            or voice_url.startswith("https://")
+            or voice_url.startswith("hf://")
+            or voice_url in PREDEFINED_VOICES
+        ):
+            raise HTTPException(
+                status_code=400, detail="voice_url must start with http://, https://, or hf://"
+            )
+        model_state = tts_model._cached_get_state_for_audio_prompt(voice_url)
+        logging.warning("Using voice from URL: %s", voice_url)
+    elif voice_wav is not None:
+        # Use uploaded voice file - preserve extension for format detection
+        suffix = Path(voice_wav.filename).suffix if voice_wav.filename else ".wav"
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+            content = voice_wav.file.read()
+            temp_file.write(content)
+            temp_file.flush()
+
+        try:
+            model_state = tts_model.get_state_for_audio_prompt(
+                Path(temp_file.name), truncate=True
+            )
+        finally:
+            os.unlink(temp_file.name)
+    else:
+        # Use default global model state
+        model_state = global_model_state
+
+    return StreamingResponse(
+        generate_data_with_state(text, model_state),
+        media_type="audio/wav",
+        headers={
+            "Content-Disposition": "attachment; filename=generated_speech.wav",
+            "Transfer-Encoding": "chunked",
+        },
+    )
+
+
+@cli_app.command()
+def serve(
+    voice: Annotated[
+        str, typer.Option(help="Path to voice prompt audio file (voice to clone)")
+    ] = DEFAULT_AUDIO_PROMPT,
+    host: Annotated[str, typer.Option(help="Host to bind to")] = "localhost",
+    port: Annotated[int, typer.Option(help="Port to bind to")] = 8000,
+    reload: Annotated[bool, typer.Option(help="Enable auto-reload")] = False,
+):
+    """Start the FastAPI server."""
+
+    global tts_model, global_model_state
+    tts_model = TTSModel.load_model(DEFAULT_VARIANT)
+
+    # Pre-load the voice prompt
+    global_model_state = tts_model.get_state_for_audio_prompt(voice)
+    logger.info(f"The size of the model state is {size_of_dict(global_model_state) // 1e6} MB")
+
+    uvicorn.run("pocket_tts.main:web_app", host=host, port=port, reload=reload)
+
+
+# ------------------------------------------------------
+# The pocket-tts single generation CLI implementation
+# ------------------------------------------------------
+
+
+@cli_app.command()
+def generate(
+    text: Annotated[
+        str, typer.Option(help="Text to generate")
+    ] = "Hello world. I am Kyutai's Pocket TTS. I'm fast enough to run on small CPUs. I hope you'll like me.",
+    voice: Annotated[
+        str, typer.Option(help="Path to audio conditioning file (voice to clone)")
+    ] = DEFAULT_AUDIO_PROMPT,
+    quiet: Annotated[bool, typer.Option("-q", "--quiet", help="Disable logging output")] = False,
+    variant: Annotated[str, typer.Option(help="Model signature")] = DEFAULT_VARIANT,
+    lsd_decode_steps: Annotated[
+        int, typer.Option(help="Number of generation steps")
+    ] = DEFAULT_LSD_DECODE_STEPS,
+    temperature: Annotated[
+        float, typer.Option(help="Temperature for generation")
+    ] = DEFAULT_TEMPERATURE,
+    noise_clamp: Annotated[float, typer.Option(help="Noise clamp value")] = DEFAULT_NOISE_CLAMP,
+    eos_threshold: Annotated[float, typer.Option(help="EOS threshold")] = DEFAULT_EOS_THRESHOLD,
+    frames_after_eos: Annotated[
+        int, typer.Option(help="Number of frames to generate after EOS")
+    ] = DEFAULT_FRAMES_AFTER_EOS,
+    output_path: Annotated[
+        str, typer.Option(help="Output path for generated audio")
+    ] = "./tts_output.wav",
+    device: Annotated[str, typer.Option(help="Device to use")] = "cpu",
+):
+    """Generate speech using Kyutai Pocket TTS."""
+    if "cuda" in device:
+        # Cuda graphs capturing does not play nice with multithreading.
+        os.environ["NO_CUDA_GRAPH"] = "1"
+
+    log_level = logging.ERROR if quiet else logging.INFO
+    with enable_logging("pocket_tts", log_level):
+        tts_model = TTSModel.load_model(
+            variant, temperature, lsd_decode_steps, noise_clamp, eos_threshold
+        )
+        tts_model.to(device)
+
+        model_state_for_voice = tts_model.get_state_for_audio_prompt(voice)
+        # Stream audio generation directly to file or stdout
+        audio_chunks = tts_model.generate_audio_stream(
+            model_state=model_state_for_voice,
+            text_to_generate=text,
+            frames_after_eos=frames_after_eos,
+        )
+
+        stream_audio_chunks(output_path, audio_chunks, tts_model.config.mimi.sample_rate)
+
+        # Only print the result message if not writing to stdout
+        if output_path != "-":
+            logger.info("Results written in %s", output_path)
+            logger.info("-" * 20)
+            logger.info(
+                "If you want to try multiple voices and prompts quickly, try the `serve` command."
+            )
+            logger.info(
+                "If you like Kyutai projects, comment, like, subscribe at https://x.com/kyutai_labs"
+            )
+
+
+if __name__ == "__main__":
+    cli_app()
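Once the server is running (via the `serve` command above), the `/tts` endpoint streams WAV bytes as they are generated. A client sketch using `requests` (an assumption; this package does not depend on it, and any HTTP client works):

```python
import requests  # assumption: not a dependency of pocket-tts

resp = requests.post(
    "http://localhost:8000/tts",
    data={"text": "Hello from Pocket TTS."},  # form field, as declared by text_to_speech
    stream=True,                              # the response body is chunked audio/wav
)
resp.raise_for_status()
with open("generated_speech.wav", "wb") as f:
    for chunk in resp.iter_content(chunk_size=8192):
        f.write(chunk)  # optionally add a voice_url form field to pick a voice
```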