pocket_tts-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pocket_tts/__init__.py +16 -0
- pocket_tts/__main__.py +6 -0
- pocket_tts/conditioners/__init__.py +0 -0
- pocket_tts/conditioners/base.py +38 -0
- pocket_tts/conditioners/text.py +61 -0
- pocket_tts/config/b6369a24.yaml +57 -0
- pocket_tts/data/__init__.py +2 -0
- pocket_tts/data/audio.py +144 -0
- pocket_tts/data/audio_utils.py +28 -0
- pocket_tts/default_parameters.py +7 -0
- pocket_tts/main.py +262 -0
- pocket_tts/models/__init__.py +3 -0
- pocket_tts/models/flow_lm.py +208 -0
- pocket_tts/models/mimi.py +111 -0
- pocket_tts/models/tts_model.py +782 -0
- pocket_tts/modules/__init__.py +1 -0
- pocket_tts/modules/conv.py +161 -0
- pocket_tts/modules/dummy_quantizer.py +18 -0
- pocket_tts/modules/layer_scale.py +11 -0
- pocket_tts/modules/mimi_transformer.py +285 -0
- pocket_tts/modules/mlp.py +215 -0
- pocket_tts/modules/resample.py +46 -0
- pocket_tts/modules/rope.py +74 -0
- pocket_tts/modules/seanet.py +180 -0
- pocket_tts/modules/stateful_module.py +45 -0
- pocket_tts/modules/transformer.py +124 -0
- pocket_tts/static/index.html +374 -0
- pocket_tts/utils/__init__.py +1 -0
- pocket_tts/utils/config.py +122 -0
- pocket_tts/utils/debugging.py +26 -0
- pocket_tts/utils/logging_utils.py +41 -0
- pocket_tts/utils/utils.py +103 -0
- pocket_tts/utils/weights_loading.py +35 -0
- pocket_tts-1.0.2.dist-info/METADATA +174 -0
- pocket_tts-1.0.2.dist-info/RECORD +38 -0
- pocket_tts-1.0.2.dist-info/WHEEL +4 -0
- pocket_tts-1.0.2.dist-info/entry_points.txt +2 -0
- pocket_tts-1.0.2.dist-info/licenses/LICENSE +23 -0
pocket_tts/modules/resample.py (@@ -0,0 +1,46 @@)

```python
import torch
from torch import nn

from pocket_tts.modules.conv import StreamingConv1d, StreamingConvTranspose1d


class ConvDownsample1d(nn.Module):
    """
    Downsampling by some integer amount `stride` using convolutions
    with a kernel size of twice the stride.
    """

    def __init__(self, stride: int, dimension: int):
        super().__init__()
        self.conv = StreamingConv1d(
            dimension,
            dimension,
            kernel_size=2 * stride,
            stride=stride,
            groups=1,
            bias=False,
            pad_mode="replicate",
        )

    def forward(self, x: torch.Tensor, model_state: dict | None):
        return self.conv(x, model_state)


class ConvTrUpsample1d(nn.Module):
    """
    Upsample by some integer amount `stride` using transposed convolutions.
    """

    def __init__(self, stride: int, dimension: int):
        super().__init__()
        self.convtr = StreamingConvTranspose1d(
            dimension,
            dimension,
            kernel_size=2 * stride,
            stride=stride,
            groups=dimension,
            bias=False,
        )

    def forward(self, x: torch.Tensor, model_state: dict | None):
        return self.convtr(x, model_state)
```
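Both wrappers delegate padding and state handling to `StreamingConv1d` / `StreamingConvTranspose1d` (defined in `conv.py`, not shown in this hunk). A minimal sketch of the length arithmetic with plain `nn.Conv1d`/`nn.ConvTranspose1d` standing in; the causal left pad of `kernel_size - stride` samples is an assumption about what the streaming wrapper does:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Plain convs stand in for the streaming wrappers; the causal left padding of
# kernel_size - stride samples is an assumption about StreamingConv1d.
stride, dim, T = 4, 16, 32
conv = nn.Conv1d(dim, dim, kernel_size=2 * stride, stride=stride, bias=False)
x = torch.randn(1, dim, T)
x = F.pad(x, (2 * stride - stride, 0), mode="replicate")
y = conv(x)
assert y.shape[-1] == T // stride  # 32 samples -> 8 frames

# The matching transposed conv maps L frames to (L - 1) * stride + 2 * stride
# samples, i.e. L * stride plus a tail of `stride` samples that a streaming
# wrapper would trim.
convtr = nn.ConvTranspose1d(dim, dim, kernel_size=2 * stride, stride=stride, groups=dim, bias=False)
z = convtr(y)
assert z.shape[-1] == y.shape[-1] * stride + stride
```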
pocket_tts/modules/rope.py (@@ -0,0 +1,74 @@)

```python
import math

import torch
from torch import nn


def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: int | torch.Tensor = 0,
    max_period: int | float = 10_000,
):
    """
    Args:
        q (torch.Tensor): Queries, shape `[B, T, H, D]`.
        k (torch.Tensor): Keys, shape `[B, T, Hk, D]`; `Hk` may differ from `H`.
        offset (int): Current offset, e.g. when streaming.
        max_period (float): Maximum period for the cos and sin.
    """

    B, T, H, D = q.shape
    Bk, Tk, Hk, Dk = k.shape
    assert (B, T, D) == (Bk, Tk, Dk)
    assert D > 0
    assert D % 2 == 0
    assert max_period > 0

    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))

    # could be optimized in one call
    ts = torch.arange(T, device=q.device, dtype=torch.float32)
    ts += offset
    ts = ts.view(-1, 1, 1)

    q = q.view(B, T, H, D // 2, 2)
    k = k.view(B, T, Hk, D // 2, 2)

    # convention is `r` suffix is real part, `i` is imaginary.
    qr = q[..., 0].float()
    qi = q[..., 1].float()

    kr = k[..., 0].float()
    ki = k[..., 1].float()

    rotr = torch.cos(freqs * ts)
    roti = torch.sin(freqs * ts)
    qor = qr * rotr - qi * roti
    qoi = qr * roti + qi * rotr

    kor = kr * rotr - ki * roti
    koi = kr * roti + ki * rotr

    dtype = q.dtype
    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)

    return qo.view(B, T, H, D), ko.view(B, T, Hk, D)


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float | int = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: torch.Tensor | int):
        """Apply the RoPE rotation to the query and key tensors."""
        return apply_rope(q, k, offset, self.max_period)
```
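Two properties worth checking against `apply_rope`: the rotation is norm-preserving (it only changes phases), and attention logits depend only on relative positions, which is what makes the streaming `offset` safe. A quick sanity check, assuming the wheel above is installed:

```python
import torch

from pocket_tts.modules.rope import apply_rope

B, T, H, D = 1, 6, 2, 8
q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)

q0, k0 = apply_rope(q, k, offset=0)
q7, k7 = apply_rope(q, k, offset=7)

# Rotation preserves norms: only the phase of each (real, imaginary) pair changes.
assert torch.allclose(q0.norm(dim=-1), q.norm(dim=-1), atol=1e-5)

# Attention logits depend only on relative position: shifting every position
# by the same offset leaves q_t . k_u unchanged.
scores0 = torch.einsum("bthd,buhd->bhtu", q0, k0)
scores7 = torch.einsum("bthd,buhd->bhtu", q7, k7)
assert torch.allclose(scores0, scores7, atol=1e-4)
```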
pocket_tts/modules/seanet.py (@@ -0,0 +1,180 @@)

```python
import numpy as np
import torch.nn as nn

from .conv import StreamingConv1d, StreamingConvTranspose1d


class SEANetResnetBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        kernel_sizes: list[int] = [3, 1],
        dilations: list[int] = [1, 1],
        pad_mode: str = "reflect",
        compress: int = 2,
    ):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), (
            "Number of kernel sizes should match number of dilations"
        )
        hidden = dim // compress
        block = nn.ModuleList([])
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                nn.ELU(alpha=1.0),
                StreamingConv1d(
                    in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, pad_mode=pad_mode
                ),
            ]
        self.block = block

    def forward(self, x, model_state: dict | None):
        v = x
        for layer in self.block:
            if isinstance(layer, StreamingConv1d):
                v = layer(v, model_state)
            else:
                v = layer(v)
        assert x.shape == v.shape, (x.shape, v.shape)
        return x + v


class SEANetEncoder(nn.Module):
    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: list[int] = [8, 5, 4, 2],
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        pad_mode: str = "reflect",
        compress: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks

        mult = 1
        model = nn.ModuleList(
            [StreamingConv1d(channels, mult * n_filters, kernel_size, pad_mode=pad_mode)]
        )
        # Downsample from raw audio scale
        for ratio in self.ratios:
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        pad_mode=pad_mode,
                        compress=compress,
                    )
                ]

            # Add downsampling layers
            model += [
                nn.ELU(alpha=1.0),
                StreamingConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2

        model += [
            nn.ELU(alpha=1.0),
            StreamingConv1d(mult * n_filters, dimension, last_kernel_size, pad_mode=pad_mode),
        ]

        self.model = model

    def forward(self, x, model_state: dict | None):
        for layer in self.model:
            if isinstance(layer, (StreamingConv1d, SEANetResnetBlock)):
                x = layer(x, model_state)
            else:
                x = layer(x)
        return x


class SEANetDecoder(nn.Module):
    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: list[int] = [8, 5, 4, 2],
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        pad_mode: str = "reflect",
        compress: int = 2,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        mult = int(2 ** len(self.ratios))
        model = nn.ModuleList(
            [StreamingConv1d(dimension, mult * n_filters, kernel_size, pad_mode=pad_mode)]
        )
        # Upsample to raw audio scale
        for ratio in self.ratios:
            # Add upsampling layers
            model += [
                nn.ELU(alpha=1.0),
                StreamingConvTranspose1d(
                    mult * n_filters, mult * n_filters // 2, kernel_size=ratio * 2, stride=ratio
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        pad_mode=pad_mode,
                        compress=compress,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            nn.ELU(alpha=1.0),
            StreamingConv1d(n_filters, channels, last_kernel_size, pad_mode=pad_mode),
        ]
        self.model = model

    def forward(self, z, model_state: dict | None):
        for layer in self.model:
            if isinstance(layer, (StreamingConvTranspose1d, SEANetResnetBlock, StreamingConv1d)):
                z = layer(z, model_state)
            else:
                z = layer(z)
        return z
```
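Both encoder and decoder derive `hop_length` from the product of the down/upsampling ratios; with the defaults above that is 320 samples per latent frame. A quick arithmetic sketch (the 24 kHz sample rate here is an assumption for illustration; the actual rate comes from the model config, not this file):

```python
import numpy as np

ratios = [8, 5, 4, 2]                # default ratios above
hop_length = int(np.prod(ratios))    # 320 samples per latent frame
assert hop_length == 320
# At an assumed 24 kHz sample rate this is 24_000 / 320 = 75 latent frames/s.
print(24_000 / hop_length)           # 75.0
```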
pocket_tts/modules/stateful_module.py (@@ -0,0 +1,45 @@)

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


def init_states(
    model: nn.Module, batch_size: int, sequence_length: int
) -> dict[str, dict[str, torch.Tensor]]:
    result = {}
    for module_name, module in model.named_modules():
        if not isinstance(module, StatefulModule):
            continue
        module._module_absolute_name = module_name
        module_state = module.init_state(batch_size, sequence_length=sequence_length)
        result[module_name] = module_state
    return result


def increment_steps(
    model: nn.Module, model_state: dict[str, dict[str, torch.Tensor]], increment: int = 1
):
    for module_name, module in model.named_modules():
        if not isinstance(module, StatefulModule):
            continue
        module.increment_step(model_state[module_name], increment)


class StatefulModule(ABC, nn.Module):
    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        self._module_absolute_name = None

    @abstractmethod
    def init_state(self, batch_size: int, sequence_length: int):
        """Initialize the state."""
        raise NotImplementedError

    def increment_step(self, state: dict, increment: int = 1):
        pass

    def get_state(self, model_state: dict[str, dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
        """Get the state for this module from the model state."""
        return model_state[self._module_absolute_name]
```
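The contract here is that all mutable data lives in the `model_state` dict, keyed by each module's absolute path, so the modules themselves carry no per-request state. A hypothetical minimal subclass to illustrate the round trip (`OffsetTracker` is invented for this sketch and assumes the wheel above is installed):

```python
import torch

from pocket_tts.modules.stateful_module import StatefulModule, increment_steps, init_states


class OffsetTracker(StatefulModule):
    """Toy stateful module: tracks how many steps it has seen."""

    def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
        # Step counter encoded as a tensor *shape*, mirroring `current_end`
        # in the attention module below.
        return {"offset": torch.zeros((0,))}

    def increment_step(self, state: dict, increment: int = 1):
        state["offset"] = torch.zeros((state["offset"].shape[0] + increment,))

    def forward(self, x: torch.Tensor, model_state: dict) -> torch.Tensor:
        state = self.get_state(model_state)
        return x + state["offset"].shape[0]  # read the step back as a plain int


tracker = OffsetTracker()
model_state = init_states(tracker, batch_size=1, sequence_length=8)
print(tracker(torch.zeros(1), model_state))   # tensor([0.])
increment_steps(tracker, model_state, increment=3)
print(tracker(torch.zeros(1), model_state))   # tensor([3.])
```

Encoding the counter as a zero-sized tensor's shape, rather than a mutable int, keeps the state a plain dict of tensors; that appears to be the design intent, though the diff does not say so explicitly.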
pocket_tts/modules/transformer.py (@@ -0,0 +1,124 @@)

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from pocket_tts.modules.rope import RotaryEmbedding
from pocket_tts.modules.stateful_module import StatefulModule


def complete_kv(
    cache: torch.Tensor, current_end: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # The write position is encoded in the *shape* of `current_end`.
    end = current_end.shape[0]

    cache[0, :, end : end + k.shape[1]] = k
    cache[1, :, end : end + v.shape[1]] = v
    valid = cache[:, :, : end + k.shape[1]]
    return valid[0], valid[1]


def _materialize_causal_mask(
    shape: tuple[int, ...], shift: int, device: str | torch.device = "cpu"
) -> torch.Tensor:
    dtype = torch.float32

    num_queries, num_keys = shape[-2:]
    # The passed `shift` can equivalently be recovered from the shapes.
    shift = num_keys - num_queries

    tensor = torch.full(shape, dtype=dtype, fill_value=1, device=device)
    mask = torch.tril(tensor, diagonal=shift).to(dtype)
    # log turns the 0/1 mask into an additive mask of -inf/0.
    mask = torch.log(mask)
    return mask.to(dtype)


class StreamingMultiheadAttention(StatefulModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        rope (`RotaryEmbedding`): Rope embedding to use.
    """

    def __init__(self, embed_dim: int, num_heads: int, rope: RotaryEmbedding):
        super().__init__()

        self.embed_dim = embed_dim
        self.rope = rope
        self.num_heads = num_heads

        out_dim = embed_dim
        num_kv = num_heads
        kv_dim = (embed_dim // num_heads) * num_kv
        out_dim += 2 * kv_dim
        # Q, K and V are produced by a single packed projection.
        self.in_proj = nn.Linear(embed_dim, out_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def _get_mask(self, shape: tuple[int, int], shift: int, device: torch.device) -> torch.Tensor:
        return _materialize_causal_mask(shape, shift=shift, device=device)

    def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
        dim_per_head = self.embed_dim // self.num_heads
        initial_current_end = torch.zeros((0,)).to(self.in_proj.weight.device)
        return dict(
            current_end=initial_current_end,
            # NaN-poisoned so that reads past the valid region fail loudly.
            cache=torch.full(
                (2, batch_size, sequence_length, self.num_heads, dim_per_head),
                float("NaN"),
                device=self.in_proj.weight.device,
                dtype=self.in_proj.weight.dtype,
            ),
        )

    def increment_step(self, state: dict, increment: int = 1):
        new_size = state["current_end"].shape[0] + increment
        state["current_end"] = torch.zeros((new_size,)).to(state["current_end"].device)

    def _complete_kv(self, k, v, state: dict):
        k, v = complete_kv(state["cache"], state["current_end"], k, v)
        return k, v

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor, state: dict):
        # Apply rope embeddings to query and key tensors.
        streaming_offset = self._streaming_offset(state)
        return self.rope(query, key, offset=streaming_offset)

    def _streaming_offset(self, state: dict) -> int:
        return state["current_end"].shape[0]

    def check_model_state(self, model_state: dict | None):
        if model_state is None:
            raise ValueError("model_state must be provided")
        return self.get_state(model_state)

    def forward(self, query: torch.Tensor, model_state: dict | None):
        state = self.check_model_state(model_state)

        projected = self.in_proj(query)
        # Reshape from (b, t, p*h*d) to (b, t, p, h, d) where p=3, h=num_heads
        b, t, _ = projected.shape
        d = self.embed_dim // self.num_heads
        packed = projected.view(b, t, 3, self.num_heads, d)
        q, k, v = torch.unbind(packed, dim=2)
        q, k = self._apply_rope(q, k, state)
        k, v = self._complete_kv(k, v, state)

        mask_shape = (query.shape[1], query.shape[1] + state["current_end"].shape[0])
        shift = state["current_end"].shape[0]

        attn_mask = self._get_mask(mask_shape, shift=shift, device=q.device)

        q, k, v = [x.transpose(1, 2) for x in (q, k, v)]
        x = F.scaled_dot_product_attention(q, k, v, attn_mask)
        x = x.transpose(1, 2)
        # Reshape from (b, t, h, d) to (b, t, h*d)
        b, t, h, d = x.shape
        x = x.reshape(b, t, h * d)
        x = self.out_proj(x)

        return x
```
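Two mechanisms above are easy to miss: `_materialize_causal_mask` builds its additive mask by taking `log` of a lower-triangular 0/1 matrix, and `complete_kv` writes into a NaN-poisoned cache whose fill position is the shape of `current_end`. A self-contained toy walk-through in plain torch (the shapes are illustrative, not the model's actual sizes):

```python
import torch

# 1. The additive causal mask: log(tril(ones)) is 0.0 where attention is
#    allowed and -inf where it is blocked. With 2 new queries over 5
#    cached+new keys, the shift is num_keys - num_queries = 3.
num_queries, num_keys = 2, 5
mask = torch.log(torch.tril(torch.ones(num_queries, num_keys), diagonal=num_keys - num_queries))
print(mask)
# tensor([[0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])

# 2. The KV cache: `current_end` encodes the write position in its *shape*,
#    and new keys/values are written into a NaN-filled buffer so any read
#    past the valid region surfaces immediately.
B, S, H, D = 1, 8, 2, 4
cache = torch.full((2, B, S, H, D), float("nan"))
current_end = torch.zeros((0,))                    # step 0
k = torch.randn(B, 3, H, D)
v = torch.randn(B, 3, H, D)
end = current_end.shape[0]
cache[0, :, end : end + k.shape[1]] = k
cache[1, :, end : end + v.shape[1]] = v
keys, values = cache[0, :, : end + k.shape[1]], cache[1, :, : end + v.shape[1]]
assert not torch.isnan(keys).any()
current_end = torch.zeros((end + k.shape[1],))     # advance the step to 3
```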