pocket_tts-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0
pocket_tts/models/flow_lm.py
@@ -0,0 +1,208 @@
import logging
from functools import partial

import torch
from beartype.typing import Callable
from torch import nn
from typing_extensions import Self

from pocket_tts.conditioners.text import LUTConditioner
from pocket_tts.modules.mimi_transformer import StreamingTransformer
from pocket_tts.modules.mlp import SimpleMLPAdaLN
from pocket_tts.utils.config import FlowLMConfig

logger = logging.getLogger(__name__)

FlowNet2 = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def lsd_decode(v_t: FlowNet2, x_0: torch.Tensor, num_steps: int = 1) -> torch.Tensor:
    """Rebuilds the data sample from starting point x_0.

    Lagrangian Self Distillation (https://arxiv.org/pdf/2505.18825)

    Args:
        v_t: Function taking s, t and x_t as input and returning the flow.
        x_0: Starting point from the known distribution.
        num_steps: Number of steps to take.

    Returns:
        x_1_hat: (B, D) Reconstructed data sample.
    """
    current = x_0
    for i in range(num_steps):
        s = i / num_steps
        t = (i + 1) / num_steps
        flow_dir = v_t(
            s * torch.ones_like(x_0[..., :1]), t * torch.ones_like(x_0[..., :1]), current
        )
        current += flow_dir / num_steps
    return current


class FlowLMModel(nn.Module):
    """Transformer-based flow language model on multiple streams of latents.

    Args:
        conditioner (LUTConditioner): Text conditioner for processing text inputs.
        flow_net (SimpleMLPAdaLN): Trainable function (cond, t, x_t) -> u_t.
        transformer (StreamingTransformer): Transformer backbone over the latent sequence.
        dim (int): Dimension of the transformer encoder.
        ldim (int): Latent dimension.
        stats_ema_decay (float): Decay for the EMA of the latent statistics.
    """

    def __init__(
        self,
        conditioner: LUTConditioner,
        flow_net: SimpleMLPAdaLN,
        transformer: StreamingTransformer,
        dim: int = 128,
        ldim: int = 64,
        stats_ema_decay: float = 0.999,
        text_padding_weight: float = 1.0,
        dtype=None,
    ):
        super().__init__()
        self.conditioner = conditioner
        self.ldim = ldim
        self.stats_ema_decay = stats_ema_decay
        self.dim = dim
        self.text_padding_weight = text_padding_weight
        self.dtype = dtype

        self.flow_net = flow_net
        self.register_buffer("emb_std", torch.ones(ldim, dtype=dtype))
        self.register_buffer("emb_mean", torch.zeros(ldim, dtype=dtype))
        self.bos_emb = torch.nn.Parameter(torch.randn(ldim, dtype=dtype))

        self.input_linear = nn.Linear(self.ldim, dim, bias=False, dtype=dtype)
        self.transformer = transformer
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.out_eos = nn.Linear(dim, 1, dtype=dtype)

    @property
    def device(self) -> str:
        return next(self.parameters()).device.type

    def forward(
        self,
        sequence: torch.Tensor,
        text_embeddings: torch.Tensor,
        model_state: dict,
        lsd_decode_steps: int,
        temp: float,
        noise_clamp: float | None,
        eos_threshold: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply language model on sequence and conditions.

        Given a tensor of sequence of shape [B, S, ldim], returns the loss in training mode
        or the reconstructed latent in generation mode.

        Args:
            sequence (torch.Tensor): Latents to model.
            text_embeddings (torch.Tensor): Pre-computed conditioning tensor.
            lsd_decode_steps (int): Number of steps to decode when generating audio.
                If zero, the model computes the loss.
        Returns:
            (output, eos_output). If `lsd_decode_steps` is zero, `output` is the loss
            tensor of shape [B, S], otherwise it is the reconstructed latent.
        """
        # NaN values signal a BOS position.
        sequence = torch.where(torch.isnan(sequence), self.bos_emb, sequence)
        input_ = self.input_linear(sequence)

        transformer_out = self.backbone(input_, text_embeddings, sequence, model_state=model_state)
        transformer_out = transformer_out.to(torch.float32)
        assert lsd_decode_steps > 0

        transformer_out = transformer_out[:, -1]
        out_eos = self.out_eos(transformer_out) > eos_threshold

        noise_shape = transformer_out.shape[:-1] + (self.ldim,)
        std = temp**0.5
        noise = torch.empty(noise_shape, dtype=transformer_out.dtype, device=transformer_out.device)
        if noise_clamp is None:
            torch.nn.init.normal_(noise, mean=0.0, std=std)
        else:
            torch.nn.init.trunc_normal_(noise, mean=0.0, std=std, a=-noise_clamp, b=noise_clamp)
        conditioned_flow = partial(self.flow_net, transformer_out)
        return lsd_decode(conditioned_flow, noise, lsd_decode_steps), out_eos

    def backbone(
        self, input_, text_embeddings: torch.Tensor, sequence, model_state: dict
    ) -> torch.Tensor:
        # Most of the time, one of these two tensors is empty, which lets us feed
        # text or audio embeddings into the model without an if-else branch.
        input_ = torch.cat([text_embeddings, input_], dim=1)
        transformer_out = self.transformer(input_, model_state)
        if self.out_norm:
            transformer_out = self.out_norm(transformer_out)
        # Remove the prefix from the model outputs (the condition is prepended).
        transformer_out = transformer_out[:, -sequence.shape[1] :]
        return transformer_out

    def _sample_next_latent(
        self,
        sequence: torch.Tensor,
        text_embeddings: torch.Tensor,
        model_state: dict,
        lsd_decode_steps: int,
        temp: float,
        noise_clamp: float | None,
        eos_threshold: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample next latent from the model given a sequence and a set of conditions.

        Args:
            sequence (torch.Tensor): Current sequence of shape [B, K, S]
                with K corresponding to the number of codebooks and S the number of sequence steps.
                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
            text_embeddings (torch.Tensor): Condition tensor.
            lsd_decode_steps (int): Number of flow steps to decode when generating audio.
        Returns:
            next_latent (torch.Tensor), is_eos (torch.Tensor): Next latent tensor of shape [B, 1, ldim]
                and is_eos tensor of shape [B, 1] with 1 on EOS positions.
        """
        result = self(
            sequence=sequence,
            text_embeddings=text_embeddings,
            lsd_decode_steps=lsd_decode_steps,
            temp=temp,
            noise_clamp=noise_clamp,
            eos_threshold=eos_threshold,
            model_state=model_state,
        )

        return result

    @classmethod
    def from_pydantic_config(cls, config: FlowLMConfig, latent_dim: int) -> Self:
        d_model = config.transformer.d_model
        flow_mlp = SimpleMLPAdaLN.from_pydantic_config(config, latent_dim, d_model)

        conditioner = LUTConditioner(
            n_bins=config.lookup_table.n_bins,
            tokenizer_path=str(config.lookup_table.tokenizer_path),
            dim=config.lookup_table.dim,
            output_dim=d_model,
        )

        transformer = StreamingTransformer.from_pydantic_config(config.transformer)

        return cls(
            flow_net=flow_mlp,
            transformer=transformer,
            dim=d_model,
            conditioner=conditioner,
            ldim=latent_dim,
            dtype=getattr(torch, config.dtype),
        )
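The `lsd_decode` helper above integrates a learned velocity field from noise to data in `num_steps` uniform Euler steps; inside `FlowLMModel.forward` that field is `partial(self.flow_net, transformer_out)`. Below is a minimal sketch of driving `lsd_decode` directly, assuming the wheel is installed; the toy velocity field, target tensor, and (2, 64) shapes are illustrative choices, not the packaged `flow_net`.

```python
# Minimal sketch: drive lsd_decode with a toy velocity field.
# `target` and the (2, 64) shape are assumptions for illustration only.
import torch

from pocket_tts.models.flow_lm import lsd_decode

target = torch.full((2, 64), 0.5)  # hypothetical data sample to recover


def toy_flow(s: torch.Tensor, t: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
    # Constant pull toward the target; matches the (s, t, x_t) -> u_t signature.
    return target - x_t


x_0 = torch.randn(2, 64)  # starting point from the known distribution
x_1_hat = lsd_decode(toy_flow, x_0, num_steps=4)
print(x_1_hat.shape)  # torch.Size([2, 64])
```

Each step adds `flow_dir / num_steps` to the current sample, so more steps trade compute for a finer integration of the flow.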
pocket_tts/models/mimi.py
@@ -0,0 +1,111 @@
import logging

import torch
from torch import nn

from pocket_tts.modules.conv import pad_for_conv1d
from pocket_tts.modules.dummy_quantizer import DummyQuantizer
from pocket_tts.modules.mimi_transformer import ProjectedTransformer
from pocket_tts.modules.resample import ConvDownsample1d, ConvTrUpsample1d
from pocket_tts.modules.seanet import SEANetDecoder, SEANetEncoder

logger = logging.getLogger()


class MimiModel(nn.Module):
    def __init__(
        self,
        encoder: SEANetEncoder,
        decoder: SEANetDecoder,
        quantizer: DummyQuantizer,
        frame_rate: float,
        encoder_frame_rate: float,
        sample_rate: int,
        channels: int,
        encoder_transformer: ProjectedTransformer,
        decoder_transformer: ProjectedTransformer,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_transformer = encoder_transformer
        self.decoder_transformer = decoder_transformer
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.encoder_frame_rate = encoder_frame_rate

        # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder
        # which exposes a `dimension` attribute.
        dimension = encoder.dimension
        assert isinstance(dimension, int), (
            f"Dimension should be int, got {dimension} of type {type(dimension)}."
        )
        self.dimension = dimension

        if encoder_frame_rate != frame_rate:
            assert self.encoder_frame_rate > self.frame_rate, "Cannot upsample with conv."
            downsample_stride = self.encoder_frame_rate / self.frame_rate
            assert downsample_stride == int(downsample_stride), (
                f"Only integer strides are supported, got {downsample_stride}"
            )
            self.downsample = ConvDownsample1d(int(downsample_stride), dimension=dimension)
            self.upsample = ConvTrUpsample1d(int(downsample_stride), dimension=dimension)

    @property
    def frame_size(self) -> int:
        return int(self.sample_rate / self.frame_rate)

    def _to_framerate(self, x: torch.Tensor):
        # Convert from the encoder frame rate to the overall framerate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        return self.downsample(x, model_state=None)

    def _to_encoder_framerate(self, x: torch.Tensor, mimi_state) -> torch.Tensor:
        # Convert from overall framerate to the encoder frame rate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        return self.upsample(x, mimi_state)

    def forward(self, x: torch.Tensor):
        raise NotImplementedError()

    def decode_from_latent(self, latent: torch.Tensor, mimi_state) -> torch.Tensor:
        emb = self._to_encoder_framerate(latent, mimi_state)
        (emb,) = self.decoder_transformer(emb, mimi_state)
        out = self.decoder(emb, mimi_state)
        # out contains extra padding added by the encoder and decoder
        return out

    def encode_to_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Projects a batch of waveforms to unquantized latent space.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T].

        Returns:
            Unquantized embeddings.
        """
        assert x.dim() == 3, (
            f"CompressionModel._encode_to_unquantized_latent expects audio of shape [B, C, T] but got {x.shape}"
        )

        frame_size = self.frame_size

        # The underlying convolutions no longer accept partial inputs,
        # `x` needs to be exactly a multiple of the frame size,
        # reproducing the previous padding behavior here.
        x = pad_for_conv1d(x, frame_size, frame_size)
        emb = self.encoder(x, model_state=None)

        (emb,) = self.encoder_transformer(emb, model_state=None)
        emb = self._to_framerate(emb)
        return emb
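`encode_to_latent` pads the waveform so its length is an exact multiple of `frame_size = int(sample_rate / frame_rate)` before the convolutional encoder runs. A small sketch of that bookkeeping follows; the sample-rate and frame-rate values are assumptions for illustration, not values read from the shipped `config/b6369a24.yaml`.

```python
# Minimal sketch of the frame-size padding requirement in encode_to_latent.
# The 24 kHz / 12.5 Hz numbers are assumed, not taken from the packaged config.
import math

sample_rate = 24_000  # assumed waveform sample rate (Hz)
frame_rate = 12.5     # assumed latent frame rate (frames/s)
frame_size = int(sample_rate / frame_rate)  # samples per latent frame -> 1920

num_samples = 50_000  # arbitrary waveform length
padded_len = math.ceil(num_samples / frame_size) * frame_size
print(frame_size, padded_len)  # 1920 51840
```

The same integer-ratio constraint shows up in `__init__`, where `encoder_frame_rate / frame_rate` must be a whole number for the convolutional down/upsampling to be valid.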