pocket_tts-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,208 @@
+ import logging
+ from functools import partial
+
+ import torch
+ from beartype.typing import Callable
+ from torch import nn
+ from typing_extensions import Self
+
+ from pocket_tts.conditioners.text import LUTConditioner
+ from pocket_tts.modules.mimi_transformer import StreamingTransformer
+ from pocket_tts.modules.mlp import SimpleMLPAdaLN
+ from pocket_tts.utils.config import FlowLMConfig
+
+ logger = logging.getLogger(__name__)
+
+ FlowNet2 = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+
+
+ def lsd_decode(v_t: FlowNet2, x_0: torch.Tensor, num_steps: int = 1) -> torch.Tensor:
+     """Rebuilds the data sample from the starting point x_0.
+
+     Lagrangian Self Distillation (https://arxiv.org/pdf/2505.18825)
+
+     Args:
+         v_t: Function taking the window endpoints s and t plus the current sample x_t,
+             and returning the flow direction.
+         x_0: Starting point drawn from the known distribution.
+         num_steps: Number of Euler steps to take.
+
+     Returns:
+         x_1_hat: (B, D) Reconstructed data sample.
+     """
+     current = x_0
+     for i in range(num_steps):
+         s = i / num_steps
+         t = (i + 1) / num_steps
+         flow_dir = v_t(
+             s * torch.ones_like(x_0[..., :1]), t * torch.ones_like(x_0[..., :1]), current
+         )
+         current += flow_dir / num_steps
+     return current
+
+
+ class FlowLMModel(nn.Module):
+     """Transformer-based flow language model over a stream of continuous audio latents.
+
+     Args:
+         conditioner (LUTConditioner): Text conditioner for processing text inputs.
+         flow_net (SimpleMLPAdaLN): Trainable flow network (cond, s, t, x_t) -> u_t.
+         transformer (StreamingTransformer): Streaming transformer backbone.
+         dim (int): Dimension of the transformer encoder.
+         ldim (int): Latent dimension.
+         stats_ema_decay (float): Decay for the EMA of the latent statistics.
+         text_padding_weight (float): Weight applied to text padding positions.
+         dtype: Optional torch dtype for parameters and buffers.
+     """
+
+     def __init__(
+         self,
+         conditioner: LUTConditioner,
+         flow_net: SimpleMLPAdaLN,
+         transformer: StreamingTransformer,
+         dim: int = 128,
+         ldim: int = 64,
+         stats_ema_decay: float = 0.999,
+         text_padding_weight: float = 1.0,
+         dtype=None,
+     ):
+         super().__init__()
+         self.conditioner = conditioner
+         self.ldim = ldim
+         self.stats_ema_decay = stats_ema_decay
+         self.dim = dim
+         self.text_padding_weight = text_padding_weight
+         self.dtype = dtype
+
+         self.flow_net = flow_net
+         self.register_buffer("emb_std", torch.ones(ldim, dtype=dtype))
+         self.register_buffer("emb_mean", torch.zeros(ldim, dtype=dtype))
+         self.bos_emb = torch.nn.Parameter(torch.randn(ldim, dtype=dtype))
+
+         self.input_linear = nn.Linear(self.ldim, dim, bias=False, dtype=dtype)
+         self.transformer = transformer
+         self.out_norm = nn.LayerNorm(dim, eps=1e-5)
+         self.out_eos = nn.Linear(dim, 1, dtype=dtype)
+
+     @property
+     def device(self) -> str:
+         return next(self.parameters()).device.type
+
91
+ def forward(
92
+ self,
93
+ sequence: torch.Tensor,
94
+ text_embeddings: torch.Tensor,
95
+ model_state: dict,
96
+ lsd_decode_steps: int,
97
+ temp: float,
98
+ noise_clamp: float | None,
99
+ eos_threshold: float,
100
+ ) -> tuple[torch.Tensor, torch.Tensor]:
101
+ """Apply language model on sequence and conditions.
102
+ Given a tensor of sequence of shape [B, S, ldim], returns the loss in training mode
103
+ or the reconstructed latent in generation mode.
104
+
105
+ Args:
106
+ sequence (torch.Tensor): Latents to model.
107
+ text_embeddings (torch.Tensor): Pre-computed conditioning
108
+ tensor.
109
+ lsd_decode_steps (int): Number of steps to decode when generating audio.
110
+ If zero, the model computes the loss.
111
+ Returns:
112
+ (output, eos_output, metrics). If `lsd_decode_steps` is zero, `output` is the loss tensor of shape [B, S],
113
+ otherwise it is the reconstructed latent.
114
+ """
+         # NaN values signal a BOS position.
+         sequence = torch.where(torch.isnan(sequence), self.bos_emb, sequence)
+         input_ = self.input_linear(sequence)
+
+         transformer_out = self.backbone(input_, text_embeddings, sequence, model_state=model_state)
+         transformer_out = transformer_out.to(torch.float32)
+         assert lsd_decode_steps > 0
+
+         transformer_out = transformer_out[:, -1]
+         out_eos = self.out_eos(transformer_out) > eos_threshold
+
+         noise_shape = transformer_out.shape[:-1] + (self.ldim,)
+         std = temp**0.5
+         noise = torch.empty(noise_shape, dtype=transformer_out.dtype, device=transformer_out.device)
+         if noise_clamp is None:
+             torch.nn.init.normal_(noise, mean=0.0, std=std)
+         else:
+             torch.nn.init.trunc_normal_(noise, mean=0.0, std=std, a=-noise_clamp, b=noise_clamp)
+         conditioned_flow = partial(self.flow_net, transformer_out)
+         return lsd_decode(conditioned_flow, noise, lsd_decode_steps), out_eos
+
+     def backbone(
+         self, input_, text_embeddings: torch.Tensor, sequence, model_state: dict
+     ) -> torch.Tensor:
+         # Most of the time one of these two tensors is empty, which lets us feed either
+         # text or audio embeddings into the model without an if/else branch.
+         input_ = torch.cat([text_embeddings, input_], dim=1)
+         transformer_out = self.transformer(input_, model_state)
+         if self.out_norm is not None:
+             transformer_out = self.out_norm(transformer_out)
+         # Remove the prefix from the model outputs (the condition is prepended).
+         transformer_out = transformer_out[:, -sequence.shape[1] :]
+         return transformer_out
+
+     def _sample_next_latent(
+         self,
+         sequence: torch.Tensor,
+         text_embeddings: torch.Tensor,
+         model_state: dict,
+         lsd_decode_steps: int,
+         temp: float,
+         noise_clamp: float | None,
+         eos_threshold: float,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Sample the next latent from the model given a sequence and its conditioning.
+
+         Args:
+             sequence (torch.Tensor): Current latent sequence of shape [B, S, ldim].
+                 S = 1 in streaming mode, except for the first step, which contains a
+                 longer prompt.
+             text_embeddings (torch.Tensor): Condition tensor.
+             model_state (dict): Streaming state passed to the transformer.
+             lsd_decode_steps (int): Number of flow steps used to decode the next latent.
+             temp (float): Sampling temperature for the initial noise.
+             noise_clamp (float | None): Optional truncation bound for the initial noise.
+             eos_threshold (float): Threshold applied to the EOS logit.
+         Returns:
+             next_latent (torch.Tensor), is_eos (torch.Tensor): Next latent of shape
+             [B, ldim] and a boolean tensor of shape [B, 1] that is True on EOS positions.
+         """
+         return self(
+             sequence=sequence,
+             text_embeddings=text_embeddings,
+             lsd_decode_steps=lsd_decode_steps,
+             temp=temp,
+             noise_clamp=noise_clamp,
+             eos_threshold=eos_threshold,
+             model_state=model_state,
+         )
+
+     @classmethod
+     def from_pydantic_config(cls, config: FlowLMConfig, latent_dim: int) -> Self:
+         d_model = config.transformer.d_model
+         flow_mlp = SimpleMLPAdaLN.from_pydantic_config(config, latent_dim, d_model)
+
+         conditioner = LUTConditioner(
+             n_bins=config.lookup_table.n_bins,
+             tokenizer_path=str(config.lookup_table.tokenizer_path),
+             dim=config.lookup_table.dim,
+             output_dim=d_model,
+         )
+
+         transformer = StreamingTransformer.from_pydantic_config(config.transformer)
+
+         return cls(
+             flow_net=flow_mlp,
+             transformer=transformer,
+             dim=d_model,
+             conditioner=conditioner,
+             ldim=latent_dim,
+             dtype=getattr(torch, config.dtype),
+         )
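
Note on the hunk above: `lsd_decode` is a plain Euler integrator of the learned velocity field over [0, 1]. At step i it evaluates the network at the window endpoints s = i / num_steps and t = (i + 1) / num_steps and advances `current` by `flow_dir / num_steps` (in place, so callers should pass a tensor they own). Below is a minimal sketch of the call shape with a hand-written velocity field standing in for the trained `flow_net`; the import path is hypothetical, since the diff shows the file's contents but not its location inside the package.

```python
import torch

# Hypothetical import path -- this diff does not show where the file lives
# inside the package.
from pocket_tts.models.flow_lm import lsd_decode

torch.manual_seed(0)
x_0 = torch.randn(4, 64)  # starting noise, shape (B, D)
x_1 = torch.randn(4, 64)  # sample the toy "flow" should reach

def v_t(s, t, x):
    # Straight-line path from x_0 to x_1: at time s the remaining displacement
    # is (x_1 - x) and a fraction (1 - s) of the path is left, so the
    # instantaneous velocity is (x_1 - x) / (1 - s). The real model passes
    # partial(flow_net, transformer_out) here instead.
    return (x_1 - x) / (1.0 - s)

one_step = lsd_decode(v_t, x_0.clone(), num_steps=1)
ten_step = lsd_decode(v_t, x_0.clone(), num_steps=10)

# Euler integration of this linear field is exact at any step count.
print(torch.allclose(one_step, x_1, atol=1e-5))  # True
print(torch.allclose(ten_step, x_1, atol=1e-5))  # True
```

In `forward`, the same loop runs with `partial(self.flow_net, transformer_out)` as `v_t` and Gaussian noise of standard deviation `temp ** 0.5` (optionally truncated to `[-noise_clamp, noise_clamp]`) as the starting point.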
@@ -0,0 +1,111 @@
+ import logging
+
+ import torch
+ from torch import nn
+
+ from pocket_tts.modules.conv import pad_for_conv1d
+ from pocket_tts.modules.dummy_quantizer import DummyQuantizer
+ from pocket_tts.modules.mimi_transformer import ProjectedTransformer
+ from pocket_tts.modules.resample import ConvDownsample1d, ConvTrUpsample1d
+ from pocket_tts.modules.seanet import SEANetDecoder, SEANetEncoder
+
+ logger = logging.getLogger(__name__)
+
+
+ class MimiModel(nn.Module):
+     def __init__(
+         self,
+         encoder: SEANetEncoder,
+         decoder: SEANetDecoder,
+         quantizer: DummyQuantizer,
+         frame_rate: float,
+         encoder_frame_rate: float,
+         sample_rate: int,
+         channels: int,
+         encoder_transformer: ProjectedTransformer,
+         decoder_transformer: ProjectedTransformer,
+     ):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.encoder_transformer = encoder_transformer
+         self.decoder_transformer = decoder_transformer
+         self.quantizer = quantizer
+         self.frame_rate = frame_rate
+         self.sample_rate = sample_rate
+         self.channels = channels
+         self.encoder_frame_rate = encoder_frame_rate
+
+         # We need the dimension for the resampling. In general the encoder is a
+         # SEANetEncoder, which exposes a `dimension` attribute.
+         dimension = encoder.dimension
+         assert isinstance(dimension, int), (
+             f"Dimension should be int, got {dimension} of type {type(dimension)}."
+         )
+         self.dimension = dimension
+
+         if encoder_frame_rate != frame_rate:
+             assert self.encoder_frame_rate > self.frame_rate, "Cannot upsample with conv."
+             downsample_stride = self.encoder_frame_rate / self.frame_rate
+             assert downsample_stride == int(downsample_stride), (
+                 f"Only integer strides are supported, got {downsample_stride}"
+             )
+             self.downsample = ConvDownsample1d(int(downsample_stride), dimension=dimension)
+             self.upsample = ConvTrUpsample1d(int(downsample_stride), dimension=dimension)
+
+     @property
+     def frame_size(self) -> int:
+         return int(self.sample_rate / self.frame_rate)
+
+     def _to_framerate(self, x: torch.Tensor) -> torch.Tensor:
+         # Convert from the encoder frame rate to the overall frame rate.
+         if self.encoder_frame_rate == self.frame_rate:
+             return x
+         return self.downsample(x, model_state=None)
+
+     def _to_encoder_framerate(self, x: torch.Tensor, mimi_state) -> torch.Tensor:
+         # Convert from the overall frame rate to the encoder frame rate.
+         if self.encoder_frame_rate == self.frame_rate:
+             return x
+         return self.upsample(x, mimi_state)
+
+     def forward(self, x: torch.Tensor):
+         raise NotImplementedError()
+
+     def decode_from_latent(self, latent: torch.Tensor, mimi_state) -> torch.Tensor:
+         emb = self._to_encoder_framerate(latent, mimi_state)
+         (emb,) = self.decoder_transformer(emb, mimi_state)
+         out = self.decoder(emb, mimi_state)
+         # out contains extra padding added by the encoder and decoder
+         return out
+
+     def encode_to_latent(self, x: torch.Tensor) -> torch.Tensor:
+         """Projects a batch of waveforms to the unquantized latent space.
+
+         Args:
+             x (torch.Tensor): Float tensor of shape [B, C, T].
+
+         Returns:
+             Unquantized embeddings.
+         """
+         assert x.dim() == 3, (
+             f"MimiModel.encode_to_latent expects audio of shape [B, C, T] but got {x.shape}"
+         )
+
+         frame_size = self.frame_size
+
+         # The underlying convolutions no longer accept partial inputs, so `x` needs to be
+         # exactly a multiple of the frame size; this reproduces the previous padding behavior.
+         x = pad_for_conv1d(x, frame_size, frame_size)
+         emb = self.encoder(x, model_state=None)
+
+         (emb,) = self.encoder_transformer(emb, model_state=None)
+         emb = self._to_framerate(emb)
+         return emb
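
The frame bookkeeping in this hunk is simple arithmetic: `frame_size = sample_rate / frame_rate` samples of audio map to one latent vector, and `pad_for_conv1d(x, frame_size, frame_size)` right-pads the waveform so its length is a whole number of frames before encoding. A small worked sketch follows; the 24 kHz / 12.5 Hz numbers are only illustrative (typical Mimi settings), since the package's actual configuration is not part of this diff.

```python
import math

# Illustrative values only -- the real sample_rate / frame_rate come from the
# package configuration, which this diff does not include.
sample_rate = 24_000   # Hz
frame_rate = 12.5      # latent frames per second
frame_size = int(sample_rate / frame_rate)  # 1920 samples per latent frame

# encode_to_latent pads the waveform up to a whole number of frames, so a
# T-sample input yields ceil(T / frame_size) latent vectors.
for T in (1_920, 2_000, 48_000):
    n_frames = math.ceil(T / frame_size)
    padded = n_frames * frame_size
    print(f"T={T:>6} samples -> padded to {padded:>6} -> {n_frames:>2} latent frame(s)")
```

Conversely, `decode_from_latent` produces `frame_size` samples per latent frame plus the extra padding noted in the code, so the waveform it returns may be slightly longer than the original input.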