pocket-tts 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0

pocket_tts/modules/__init__.py
@@ -0,0 +1 @@
+"""Modules used for building the models."""

pocket_tts/modules/conv.py
@@ -0,0 +1,161 @@
+import math
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from pocket_tts.modules.stateful_module import StatefulModule
+
+
+def get_extra_padding_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+            1   2   3       # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+class StreamingConv1d(StatefulModule):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        pad_mode: str = "constant",
+    ):
+        super().__init__()
+        assert pad_mode in ["constant", "replicate"], pad_mode
+        self.pad_mode = pad_mode
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn(
+                "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
+                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+            )
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+    @property
+    def _stride(self) -> int:
+        return self.conv.stride[0]
+
+    @property
+    def _kernel_size(self) -> int:
+        return self.conv.kernel_size[0]
+
+    @property
+    def _effective_kernel_size(self) -> int:
+        dilation = self.conv.dilation[0]
+        return (self._kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+
+    def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+        stride = self._stride
+        # Effective kernel size accounting for dilation.
+        kernel = self._effective_kernel_size
+        previous = torch.zeros(batch_size, self.conv.in_channels, kernel - stride)
+        first = torch.ones(batch_size, dtype=torch.bool)
+        return dict(previous=previous, first=first)
+
+    def forward(self, x, model_state: dict | None):
+        B, C, T = x.shape
+        S = self._stride
+        assert T > 0 and T % S == 0, "Steps must be multiple of stride"
+        if model_state is None:
+            state = self.init_state(B, 0)
+        else:
+            state = self.get_state(model_state)
+        TP = state["previous"].shape[-1]
+        if TP and self.pad_mode == "replicate":
+            assert T >= TP, "Not enough content to pad streaming."
+            init = x[..., :1]
+            state["previous"][:] = torch.where(
+                state["first"].view(-1, 1, 1), init, state["previous"]
+            )
+        if TP:
+            x = torch.cat([state["previous"], x], dim=-1)
+        y = self.conv(x)
+        if TP:
+            state["previous"][:] = x[..., -TP:]
+            if self.pad_mode == "replicate":
+                state["first"] = torch.zeros_like(state["first"])
+        return y
+
+
+class StreamingConvTranspose1d(StatefulModule):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.convtr = nn.ConvTranspose1d(
+            in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias
+        )
+
+    @property
+    def _stride(self) -> int:
+        return self.convtr.stride[0]
+
+    @property
+    def _kernel_size(self) -> int:
+        return self.convtr.kernel_size[0]
+
+    def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+        K = self._kernel_size
+        S = self._stride
+        return dict(partial=torch.zeros(batch_size, self.convtr.out_channels, K - S))
+
+    def forward(self, x, mimi_state: dict):
+        layer_state = self.get_state(mimi_state)["partial"]
+        y = self.convtr(x)
+        PT = layer_state.shape[-1]
+        if PT > 0:
+            y[..., :PT] += layer_state
+            bias = self.convtr.bias
+            for_partial = y[..., -PT:]
+            if bias is not None:
+                for_partial -= bias[:, None]
+            layer_state[:] = for_partial
+            y = y[..., :-PT]
+        return y
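
StreamingConv1d carries a `previous` buffer of the last `kernel - stride` input samples between calls, which is exactly the left context a causal convolution needs to produce the same outputs chunk by chunk as it would in one pass over the whole signal. Below is a minimal illustrative sketch of that buffering idea; it is not code from the wheel and uses a plain nn.Conv1d instead of the package's StatefulModule state handling.

# Sketch only: chunked causal convolution with a hand-carried context buffer.
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
K, S = 4, 2
conv = nn.Conv1d(1, 1, K, stride=S)
x = torch.randn(1, 1, 16)

# Non-streaming reference: left-pad by K - S (causal padding), run once.
full = conv(F.pad(x, (K - S, 0)))

# Streaming: keep a rolling buffer of the last K - S samples between chunks.
previous = torch.zeros(1, 1, K - S)
outs = []
for chunk in x.split(4, dim=-1):  # chunk length must be a multiple of the stride
    buffered = torch.cat([previous, chunk], dim=-1)
    outs.append(conv(buffered))
    previous = buffered[..., -(K - S):]
streamed = torch.cat(outs, dim=-1)

assert torch.allclose(full, streamed, atol=1e-6)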

pocket_tts/modules/dummy_quantizer.py
@@ -0,0 +1,18 @@
+import torch
+from torch import nn
+
+
+class DummyQuantizer(nn.Module):
+    """Simplified quantizer that only provides output projection for TTS.
+
+    This removes all unnecessary quantization logic since we don't use actual quantization.
+    """
+
+    def __init__(self, dimension: int, output_dimension: int):
+        super().__init__()
+        self.dimension = dimension
+        self.output_dimension = output_dimension
+        self.output_proj = torch.nn.Conv1d(self.dimension, self.output_dimension, 1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.output_proj(x)
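
In practice DummyQuantizer is just a 1x1 convolution, i.e. a per-frame linear projection of the latent channels. A small shape sketch, illustrative only; the channel sizes are made up and it assumes pocket-tts is installed:

import torch
from pocket_tts.modules.dummy_quantizer import DummyQuantizer

proj = DummyQuantizer(dimension=512, output_dimension=1024)  # hypothetical sizes
latents = torch.randn(2, 512, 50)                            # [batch, dimension, frames]
print(proj(latents).shape)                                   # torch.Size([2, 1024, 50])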

pocket_tts/modules/layer_scale.py
@@ -0,0 +1,11 @@
+import torch
+import torch.nn as nn
+
+
+class LayerScale(nn.Module):
+    def __init__(self, channels: int, init: float):
+        super().__init__()
+        self.scale = nn.Parameter(torch.full((channels,), init))
+
+    def forward(self, x: torch.Tensor):
+        return self.scale * x
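
LayerScale multiplies each channel of a residual update by a learned gain; StreamingTransformerLayer below applies it to the attention and feed-forward branches, and a small init makes each layer start close to the identity. A tiny illustrative sketch (channel count and init are arbitrary; assumes pocket-tts is installed):

import torch
from pocket_tts.modules.layer_scale import LayerScale

ls = LayerScale(channels=8, init=1e-4)
update = torch.randn(2, 3, 8)   # [..., channels]; the scale broadcasts over leading dims
print(ls(update).abs().max())   # roughly 1e-4 times the largest |update| at init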

pocket_tts/modules/mimi_transformer.py
@@ -0,0 +1,285 @@
+from typing import NamedTuple
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch.nn import functional as F
+from typing_extensions import Self
+
+from pocket_tts.modules.layer_scale import LayerScale
+from pocket_tts.modules.rope import RotaryEmbedding
+from pocket_tts.modules.stateful_module import StatefulModule
+from pocket_tts.modules.transformer import StreamingMultiheadAttention
+from pocket_tts.utils.config import FlowLMTransformerConfig
+
+
+class KVCacheResult(NamedTuple):
+    keys: torch.Tensor
+    values: torch.Tensor
+    positions: torch.Tensor
+
+    @staticmethod
+    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
+        B, H, T, D = keys.shape
+        assert tuple(values.shape[:-1]) == (B, H, T)
+        positions = torch.arange(T, device=keys.device, dtype=torch.long)
+        return KVCacheResult(keys, values, positions.expand(B, -1))
+
+
+def complete(
+    cache: torch.Tensor, end_offset: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> KVCacheResult:
+    capacity = cache.shape[3]
+    assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
+    B, H, T, D = k.shape
+    assert T > 0
+    indexes = torch.arange(T, device=end_offset.device, dtype=end_offset.dtype)
+    indexes = indexes + end_offset.view(-1, 1)
+    indexes = indexes % capacity
+    # indexes is [B, T]
+    # k is [B, H, T, D]
+    # cache is [B, H, T', D]
+    this_indexes = indexes.view(B, 1, T, 1)
+    this_indexes = this_indexes.expand(-1, H, T, D)
+    cache[0].scatter_(2, this_indexes, k)
+    cache[1].scatter_(2, this_indexes, v)
+
+    keys = cache[0]
+    values = cache[1]
+
+    indexes = torch.arange(capacity, device=end_offset.device, dtype=torch.long)
+
+    # end_index correspond to the actual index where the last value was written.
+    last_offset = end_offset.view(-1, 1) + T - 1
+    end_index = last_offset % capacity
+    delta = indexes - end_index
+
+    positions = torch.where(delta <= 0, last_offset + delta, last_offset + delta - capacity)
+    end_offset[:] = end_offset + T
+    invalid = indexes >= end_offset.view(-1, 1)
+    positions = torch.where(invalid, torch.full_like(positions, -1), positions)
+
+    return KVCacheResult(keys, values, positions)
+
+
+class MimiStreamingMultiheadAttention(StatefulModule):
+    def __init__(self, embed_dim: int, num_heads: int, context: int, rope: RotaryEmbedding):
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.context = context
+        self.rope = rope
+        self.num_heads = num_heads
+        out_dim = 3 * embed_dim
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.in_proj = nn.Linear(embed_dim, out_dim, bias=False)
+
+    def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+        dim_per_head = self.embed_dim // self.num_heads
+
+        state = {}
+        state["offset"] = torch.zeros(batch_size, dtype=torch.long)
+        state["cache"] = torch.zeros((2, batch_size, self.num_heads, sequence_length, dim_per_head))
+        state["end_offset"] = torch.zeros(batch_size, dtype=torch.long)
+        return state
+
+    def increment_step(self, state, increment: int = 1):
+        state["offset"] += increment
+
+    def _complete_kv(self, k, v, model_state: dict | None) -> KVCacheResult:
+        if model_state is None:
+            return KVCacheResult.from_kv(k, v)
+        else:
+            layer_state = self.get_state(model_state)
+            return complete(layer_state["cache"], layer_state["end_offset"], k, v)
+
+    def forward(self, query: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+        B, T = query.shape[:2]
+
+        if model_state is None:
+            offset = torch.zeros(B, device=query.device, dtype=torch.long)
+        else:
+            offset = self.get_state(model_state)["offset"]
+
+        projected = self.in_proj(query)
+
+        q, k, v = rearrange(projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads)
+
+        # Permute from [b, h, t, d] to [b, t, h, d] for rope
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        q, k = self.rope(q, k, offset)
+        # Permute back from [b, t, h, d] to [b, h, t, d]
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+
+        k, v, pos_k = self._complete_kv(k, v, model_state)
+        pos_k = pos_k[:, None]
+        pos_q = offset.view(-1, 1, 1) + torch.arange(T, device=q.device, dtype=torch.long).view(
+            -1, 1
+        )
+        delta = pos_q - pos_k
+        attn_bias = (pos_k >= 0) & (delta >= 0)
+        attn_bias = attn_bias & (delta < self.context)
+        attn_bias = attn_bias[:, None]
+
+        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
+
+        x = rearrange(x, "b h t d -> b t (h d)")
+        x = self.out_proj(x)
+        return x
+
+
+class StreamingTransformerLayer(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        dim_feedforward: int,
+        context: int | None,
+        rope: RotaryEmbedding,
+        layer_scale: float | None = None,
+        attention_kind: str = "mimi",
+    ):
+        super().__init__()
+        # Redefine self_attn to our streaming multi-head attention
+        if attention_kind == "mimi":
+            # TODO: we should actually use StreamingMultiheadAttention here and add context window
+            # support. And we should then delete MimiStreamingMultiheadAttention.
+            # The implementation is really close.
+            self.self_attn = MimiStreamingMultiheadAttention(
+                context=context, rope=rope, embed_dim=d_model, num_heads=num_heads
+            )
+        else:
+            self.self_attn = StreamingMultiheadAttention(
+                rope=rope, embed_dim=d_model, num_heads=num_heads
+            )
+        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
+        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
+
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
+
+        if layer_scale is None:
+            self.layer_scale_1 = nn.Identity()
+            self.layer_scale_2 = nn.Identity()
+        else:
+            self.layer_scale_1 = LayerScale(d_model, layer_scale)
+            self.layer_scale_2 = LayerScale(d_model, layer_scale)
+
+    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
+        x_orig = x
+        x = self.norm2(x)
+        update = self.linear2(F.gelu(self.linear1(x)))
+        return x_orig.to(update) + self.layer_scale_2(update)
+
+    def _sa_block(self, x: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+        x_orig = x
+        x = self.norm1(x)
+        update = self.self_attn(x, model_state)
+        return x_orig.to(update) + self.layer_scale_1(update)
+
+    def forward(self, x: torch.Tensor, model_state: dict | None) -> torch.Tensor:
+        x = self._sa_block(x, model_state)
+        x = self._ff_block(x)
+        return x
+
+
+class StreamingTransformer(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        num_layers: int,
+        layer_scale: float | None = None,
+        dim_feedforward: int | list[int] = 2048,
+        context: int | None = None,
+        max_period: float = 10_000.0,
+        kind: str = "mimi",
+    ):
+        super().__init__()
+        assert d_model % num_heads == 0
+        self.max_period = max_period
+
+        self.rope = RotaryEmbedding(max_period=max_period)
+
+        self.layers = nn.ModuleList()
+        for _ in range(num_layers):
+            self.layers.append(
+                StreamingTransformerLayer(
+                    d_model=d_model,
+                    num_heads=num_heads,
+                    dim_feedforward=dim_feedforward,
+                    context=context,
+                    rope=self.rope,
+                    layer_scale=layer_scale,
+                    attention_kind=kind,
+                )
+            )
+
+    @classmethod
+    def from_pydantic_config(cls, config: FlowLMTransformerConfig) -> Self:
+        dim_feedforward = int(config.d_model * config.hidden_scale)
+        return cls(
+            d_model=config.d_model,
+            num_heads=config.num_heads,
+            num_layers=config.num_layers,
+            dim_feedforward=dim_feedforward,
+            max_period=float(config.max_period),
+            kind="flow_lm",
+        )
+
+    def forward(self, x: torch.Tensor, model_state: dict | None):
+        for layer in self.layers:
+            x = layer(x, model_state)
+        return x
+
+
+class ProjectedTransformer(nn.Module):
+    def __init__(
+        self,
+        input_dimension: int,
+        output_dimensions: tuple[int, ...],
+        d_model: int,
+        num_heads: int,
+        num_layers: int,
+        layer_scale: float,
+        context: int,
+        max_period: float,
+        dim_feedforward: int,
+    ):
+        super().__init__()
+        self.transformer = StreamingTransformer(
+            d_model=d_model,
+            num_heads=num_heads,
+            num_layers=num_layers,
+            layer_scale=layer_scale,
+            context=context,
+            max_period=max_period,
+            dim_feedforward=dim_feedforward,
+        )
+        self.input_dimension = input_dimension
+        self.output_dimensions = output_dimensions
+        self.input_proj = None
+        if d_model != input_dimension:
+            self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
+
+        self.output_projs = nn.ModuleList()
+        for output_dimension in output_dimensions:
+            if d_model == output_dimension:
+                self.output_projs.append(nn.Identity())
+            else:
+                self.output_projs.append(nn.Linear(d_model, output_dimension, bias=False))
+
+    def forward(self, x, model_state: dict | None):
+        x = x.transpose(1, 2)
+        if self.input_proj is not None:
+            x = self.input_proj(x)
+        z = self.transformer(x, model_state)
+        ys = []
+        for output_proj in self.output_projs:
+            y = output_proj(z)
+            y = y.transpose(1, 2)
+            ys.append(y)
+        return ys
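
The least obvious part of this file is `complete`, which treats `cache` as a ring buffer: new keys/values are scattered at `end_offset % capacity`, and the returned `positions` report each slot's absolute time step, with -1 marking still-empty slots. `MimiStreamingMultiheadAttention.forward` then masks attention to keys with `pos_k >= 0` and `0 <= pos_q - pos_k < context`. A small illustrative walk-through (not package code; assumes pocket-tts is installed and uses a toy capacity of 4):

import torch
from pocket_tts.modules.mimi_transformer import complete

B, H, D, capacity = 1, 1, 2, 4
cache = torch.zeros(2, B, H, capacity, D)      # [k/v, B, H, capacity, D]
end_offset = torch.zeros(B, dtype=torch.long)  # advanced in place by `complete`

def step(T: int) -> torch.Tensor:
    k = torch.randn(B, H, T, D)
    v = torch.randn(B, H, T, D)
    return complete(cache, end_offset, k, v).positions

print(step(3))  # slots hold absolute steps [0, 1, 2, -1]: slot 3 is still empty
print(step(2))  # now [4, 1, 2, 3]: step 3 filled slot 3, step 4 wrapped around to slot 0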

pocket_tts/modules/mlp.py
@@ -0,0 +1,215 @@
+"""
+Taken from
+https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py
+
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+from typing_extensions import Self
+
+from pocket_tts.utils.config import FlowLMConfig
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale) + shift
+
+
+def _rms_norm(x: torch.Tensor, alpha: torch.Tensor, eps: float):
+    assert x.dim() >= alpha.dim()
+    x_dtype = x.dtype
+    var = eps + x.var(dim=-1, keepdim=True)
+    y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
+    return y
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        alpha_shape = (dim,)
+        self.alpha = nn.Parameter(torch.full(alpha_shape, 1.0, requires_grad=True))
+
+    def forward(self, x: torch.Tensor):
+        return _rms_norm(x, self.alpha, self.eps)
+
+
+class LayerNorm(nn.Module):
+    """Reimplementation of LayerNorm because the default one doesn't support jvp."""
+
+    def __init__(self, channels, eps=1e-6, elementwise_affine=True):
+        super().__init__()
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(channels))
+            self.bias = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        var = x.var(dim=-1, unbiased=False, keepdim=True)
+        x = (x - mean) / torch.sqrt(var + self.eps)
+        if hasattr(self, "weight"):
+            x = x * self.weight + self.bias
+        return x
+
+
+class TimestepEmbedder(nn.Module):
+    """Embeds scalar timesteps into vector representations."""
+
+    def __init__(
+        self, hidden_size: int, frequency_embedding_size: int = 256, max_period: int = 10000
+    ):
+        super().__init__()
+        blocks = [
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        ]
+        blocks.append(RMSNorm(hidden_size))
+        self.mlp = nn.Sequential(*blocks)
+        self.frequency_embedding_size = frequency_embedding_size
+        half = frequency_embedding_size // 2
+        self.register_buffer(
+            "freqs", torch.exp(-math.log(max_period) * torch.arange(start=0, end=half) / half)
+        )
+
+    def forward(self, t):
+        args = t * self.freqs.to(t.dtype)
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        assert not (self.frequency_embedding_size % 2)
+        t_emb = self.mlp(embedding)
+        return t_emb
+
+
+class ResBlock(nn.Module):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    """
+
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+
+        self.in_ln = LayerNorm(channels, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(channels, channels, bias=True),
+            nn.SiLU(),
+            nn.Linear(channels, channels, bias=True),
+        )
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
+        )
+
+    def forward(self, x, y):
+        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+        h = self.mlp(h)
+        return x + gate_mlp * h
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer adopted from DiT.
+    """
+
+    def __init__(self, model_channels, out_channels):
+        super().__init__()
+        self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(model_channels, out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+    """Taken from https://arxiv.org/abs/2406.11838.
+
+    The MLP for Diffusion Loss.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param cond_channels: channels in the condition.
+    :param num_res_blocks: number of residual blocks per downsample.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        model_channels,
+        out_channels,
+        cond_channels,
+        num_res_blocks,
+        num_time_conds=1,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.num_time_conds = num_time_conds
+
+        assert num_time_conds != 1
+        self.time_embed = nn.ModuleList(
+            [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
+        )
+        self.cond_embed = nn.Linear(cond_channels, model_channels)
+
+        self.input_proj = nn.Linear(in_channels, model_channels)
+
+        res_blocks = []
+        for i in range(num_res_blocks):
+            res_blocks.append(ResBlock(model_channels))
+
+        self.res_blocks = nn.ModuleList(res_blocks)
+        self.final_layer = FinalLayer(model_channels, out_channels)
+
+    @classmethod
+    def from_pydantic_config(cls, cfg: FlowLMConfig, latent_dim: int, cond_dim: int) -> Self:
+        config = cfg.flow
+
+        flow_dim = config.dim
+        flow_depth = config.depth
+        num_time_conds = 2
+        return SimpleMLPAdaLN(
+            latent_dim, flow_dim, latent_dim, cond_dim, flow_depth, num_time_conds=num_time_conds
+        )
+
+    def forward(
+        self, c: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Apply the model to an input batch.
+        :param c: conditioning from AR transformer.
+        :param s: start time tensor.
+        :param t: target time tensor.
+        :param x: an [N x C] Tensor of inputs.
+        :return: an [N x C] Tensor of outputs.
+        """
+        # Combine time conditions
+        ts = [s, t]
+        x = self.input_proj(x)
+        assert len(ts) == self.num_time_conds, (
+            f"Expected {self.num_time_conds} time conditions, got {len(ts)}"
+        )
+        assert self.num_time_conds != 1
+        t_combined = (
+            sum(self.time_embed[i](ts[i]) for i in range(self.num_time_conds)) / self.num_time_conds
+        )
+        c = self.cond_embed(c)
+        y = t_combined + c
+
+        for block in self.res_blocks:
+            x = block(x, y)
+
+        return self.final_layer(x, y)
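
SimpleMLPAdaLN is the diffusion-loss MLP head: `x` is projected to the model width, the two scalar times `s` and `t` are sinusoidally embedded by TimestepEmbedder and averaged, added to the projected conditioning `c`, and the sum modulates every ResBlock through DiT-style adaLN (modulate(x, shift, scale) = x * (1 + scale) + shift, plus a gate on the residual). A shape sketch with made-up sizes, illustrative only; in the package the real dimensions come from FlowLMConfig via from_pydantic_config, which sets num_time_conds=2:

import torch
from pocket_tts.modules.mlp import SimpleMLPAdaLN

net = SimpleMLPAdaLN(
    in_channels=32, model_channels=256, out_channels=32,
    cond_channels=512, num_res_blocks=3, num_time_conds=2,
)
N = 4
x = torch.randn(N, 32)        # latent being denoised / integrated
c = torch.randn(N, 512)       # conditioning from the AR transformer
s = torch.rand(N, 1)          # start time
t = torch.rand(N, 1)          # target time
print(net(c, s, t, x).shape)  # torch.Size([4, 32])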