pocket-tts 1.0.2 (pocket_tts-1.0.2-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,46 @@
+ import torch
+ from torch import nn
+
+ from pocket_tts.modules.conv import StreamingConv1d, StreamingConvTranspose1d
+
+
+ class ConvDownsample1d(nn.Module):
+     """
+     Downsampling by some integer amount `stride` using convolutions
+     with a kernel size of twice the stride.
+     """
+
+     def __init__(self, stride: int, dimension: int):
+         super().__init__()
+         self.conv = StreamingConv1d(
+             dimension,
+             dimension,
+             kernel_size=2 * stride,
+             stride=stride,
+             groups=1,
+             bias=False,
+             pad_mode="replicate",
+         )
+
+     def forward(self, x: torch.Tensor, model_state: dict | None):
+         return self.conv(x, model_state)
+
+
+ class ConvTrUpsample1d(nn.Module):
+     """
+     Upsample by some integer amount `stride` using transposed convolutions.
+     """
+
+     def __init__(self, stride: int, dimension: int):
+         super().__init__()
+         self.convtr = StreamingConvTranspose1d(
+             dimension,
+             dimension,
+             kernel_size=2 * stride,
+             stride=stride,
+             groups=dimension,
+             bias=False,
+         )
+
+     def forward(self, x: torch.Tensor, model_state: dict | None):
+         return self.convtr(x, model_state)
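A minimal usage sketch of the resampling pair follows. It is not part of the wheel: the import path `pocket_tts.modules.resample`, the (batch, channels, time) input layout, and the use of `init_states` to build a streaming state for the wrapped StreamingConv1d are all assumptions, not confirmed by this diff.

# Hypothetical usage sketch; module path and tensor layout are assumptions.
import torch

from pocket_tts.modules.resample import ConvDownsample1d, ConvTrUpsample1d  # assumed path
from pocket_tts.modules.stateful_module import init_states

stride, dim = 4, 128
down = ConvDownsample1d(stride=stride, dimension=dim)
up = ConvTrUpsample1d(stride=stride, dimension=dim)

x = torch.randn(1, dim, 64)  # assuming channel-first Conv1d input (batch, channels, time)
down_state = init_states(down, batch_size=1, sequence_length=64)
up_state = init_states(up, batch_size=1, sequence_length=64)

y = down(x, down_state)  # time axis shrunk roughly by `stride`
z = up(y, up_state)      # time axis expanded back roughly by `stride`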
@@ -0,0 +1,74 @@
+ import math
+
+ import torch
+ from torch import nn
+
+
+ def apply_rope(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     offset: int | torch.Tensor = 0,
+     max_period: int | float = 10_000,
+ ):
+     """
+     Args:
+         q (torch.Tensor): Queries, shape `[B, T, H, D]`.
+         k (torch.Tensor): Keys, shape `[B, T, H, D]`.
+         offset (int or torch.Tensor): Current offset, e.g. when streaming.
+         max_period (float): Maximum period for the cos and sin.
+     """
+
+     B, T, H, D = q.shape
+     Bk, Tk, Hk, Dk = k.shape
+     assert (B, T, D) == (Bk, Tk, Dk)
+     assert D > 0
+     assert D % 2 == 0
+     assert max_period > 0
+
+     ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
+     freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
+
+     # could be optimized in one call
+     ts = torch.arange(T, device=q.device, dtype=torch.float32)
+     ts += offset
+     ts = ts.view(-1, 1, 1)
+
+     q = q.view(B, T, H, D // 2, 2)
+     k = k.view(B, T, Hk, D // 2, 2)
+
+     # convention is `r` suffix is real part, `i` is imaginary.
+     qr = q[..., 0].float()
+     qi = q[..., 1].float()
+
+     kr = k[..., 0].float()
+     ki = k[..., 1].float()
+
+     rotr = torch.cos(freqs * ts)
+     roti = torch.sin(freqs * ts)
+     qor = qr * rotr - qi * roti
+     qoi = qr * roti + qi * rotr
+
+     kor = kr * rotr - ki * roti
+     koi = kr * roti + ki * rotr
+
+     dtype = q.dtype
+     qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
+     ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
+
+     return qo.view(B, T, H, D), ko.view(B, T, Hk, D)
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary positional embedding (RoPE) from [Su et al. 2021](https://arxiv.org/abs/2104.09864).
+
+     Args:
+         max_period (float): Maximum period of the rotation frequencies.
+     """
+
+     def __init__(self, max_period: float | int = 10000.0):
+         super().__init__()
+         self.max_period = max_period
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor, offset: torch.Tensor | int):
+         """Apply rope rotation to the query and key tensors."""
+         return apply_rope(q, k, offset, self.max_period)
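A short self-contained check of the rotation above: because each (real, imaginary) pair is rotated by an angle proportional to its position, RoPE preserves vector norms, and shifting queries and keys by the same streaming offset leaves the q·k attention scores unchanged. This is a sanity-check sketch that assumes only `apply_rope` from the module above.

# Sanity check of apply_rope's relative-position property (illustrative only).
import torch

from pocket_tts.modules.rope import apply_rope

B, T, H, D = 1, 8, 2, 16
q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)

q0, k0 = apply_rope(q, k, offset=0)
q5, k5 = apply_rope(q, k, offset=5)

# The per-pair rotation is orthogonal, so norms are preserved.
assert torch.allclose(q.norm(dim=-1), q0.norm(dim=-1), atol=1e-4)

# Scores depend only on relative positions: moving both q and k by the same
# offset changes the individual embeddings but not their inner products.
scores0 = torch.einsum("bthd,bshd->bhts", q0, k0)
scores5 = torch.einsum("bthd,bshd->bhts", q5, k5)
assert torch.allclose(scores0, scores5, atol=1e-4)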
@@ -0,0 +1,180 @@
+ import numpy as np
+ import torch.nn as nn
+
+ from .conv import StreamingConv1d, StreamingConvTranspose1d
+
+
+ class SEANetResnetBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         kernel_sizes: list[int] = [3, 1],
+         dilations: list[int] = [1, 1],
+         pad_mode: str = "reflect",
+         compress: int = 2,
+     ):
+         super().__init__()
+         assert len(kernel_sizes) == len(dilations), (
+             "Number of kernel sizes should match number of dilations"
+         )
+         hidden = dim // compress
+         block = nn.ModuleList([])
+         for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+             in_chs = dim if i == 0 else hidden
+             out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+             block += [
+                 nn.ELU(alpha=1.0),
+                 StreamingConv1d(
+                     in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, pad_mode=pad_mode
+                 ),
+             ]
+         self.block = block
+
+     def forward(self, x, model_state: dict | None):
+         v = x
+         for layer in self.block:
+             if isinstance(layer, StreamingConv1d):
+                 v = layer(v, model_state)
+             else:
+                 v = layer(v)
+         assert x.shape == v.shape, (x.shape, v.shape)
+         return x + v
+
+
+ class SEANetEncoder(nn.Module):
+     def __init__(
+         self,
+         channels: int = 1,
+         dimension: int = 128,
+         n_filters: int = 32,
+         n_residual_layers: int = 3,
+         ratios: list[int] = [8, 5, 4, 2],
+         kernel_size: int = 7,
+         last_kernel_size: int = 7,
+         residual_kernel_size: int = 3,
+         dilation_base: int = 2,
+         pad_mode: str = "reflect",
+         compress: int = 2,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.dimension = dimension
+         self.n_filters = n_filters
+         self.ratios = list(reversed(ratios))
+         del ratios
+         self.n_residual_layers = n_residual_layers
+         self.hop_length = int(np.prod(self.ratios))
+         self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+
+         mult = 1
+         model = nn.ModuleList(
+             [StreamingConv1d(channels, mult * n_filters, kernel_size, pad_mode=pad_mode)]
+         )
+         # Downsample to raw audio scale
+         for i, ratio in enumerate(self.ratios):
+             # Add residual layers
+             for j in range(n_residual_layers):
+                 model += [
+                     SEANetResnetBlock(
+                         mult * n_filters,
+                         kernel_sizes=[residual_kernel_size, 1],
+                         dilations=[dilation_base**j, 1],
+                         pad_mode=pad_mode,
+                         compress=compress,
+                     )
+                 ]
+
+             # Add downsampling layers
+             model += [
+                 nn.ELU(alpha=1.0),
+                 StreamingConv1d(
+                     mult * n_filters,
+                     mult * n_filters * 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                     pad_mode=pad_mode,
+                 ),
+             ]
+             mult *= 2
+
+         model += [
+             nn.ELU(alpha=1.0),
+             StreamingConv1d(mult * n_filters, dimension, last_kernel_size, pad_mode=pad_mode),
+         ]
+
+         self.model = model
+
+     def forward(self, x, model_state: dict | None):
+         for layer in self.model:
+             if isinstance(layer, (StreamingConv1d, SEANetResnetBlock)):
+                 x = layer(x, model_state)
+             else:
+                 x = layer(x)
+         return x
+
+
+ class SEANetDecoder(nn.Module):
+     def __init__(
+         self,
+         channels: int = 1,
+         dimension: int = 128,
+         n_filters: int = 32,
+         n_residual_layers: int = 3,
+         ratios: list[int] = [8, 5, 4, 2],
+         kernel_size: int = 7,
+         last_kernel_size: int = 7,
+         residual_kernel_size: int = 3,
+         dilation_base: int = 2,
+         pad_mode: str = "reflect",
+         compress: int = 2,
+     ):
+         super().__init__()
+         self.dimension = dimension
+         self.channels = channels
+         self.n_filters = n_filters
+         self.ratios = ratios
+         del ratios
+         self.n_residual_layers = n_residual_layers
+         self.hop_length = int(np.prod(self.ratios))
+         self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+         mult = int(2 ** len(self.ratios))
+         model = nn.ModuleList(
+             [StreamingConv1d(dimension, mult * n_filters, kernel_size, pad_mode=pad_mode)]
+         )
+         # Upsample to raw audio scale
+         for _, ratio in enumerate(self.ratios):
+             # Add upsampling layers
+             model += [
+                 nn.ELU(alpha=1.0),
+                 StreamingConvTranspose1d(
+                     mult * n_filters, mult * n_filters // 2, kernel_size=ratio * 2, stride=ratio
+                 ),
+             ]
+             # Add residual layers
+             for j in range(n_residual_layers):
+                 model += [
+                     SEANetResnetBlock(
+                         mult * n_filters // 2,
+                         kernel_sizes=[residual_kernel_size, 1],
+                         dilations=[dilation_base**j, 1],
+                         pad_mode=pad_mode,
+                         compress=compress,
+                     )
+                 ]
+
+             mult //= 2
+
+         # Add final layers
+         model += [
+             nn.ELU(alpha=1.0),
+             StreamingConv1d(n_filters, channels, last_kernel_size, pad_mode=pad_mode),
+         ]
+         self.model = model
+
+     def forward(self, z, model_state: dict | None):
+         for layer in self.model:
+             if isinstance(layer, (StreamingConvTranspose1d, SEANetResnetBlock, StreamingConv1d)):
+                 z = layer(z, model_state)
+             else:
+                 z = layer(z)
+         return z
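A worked example of the downsampling factor implied by the defaults above: `hop_length` is the product of the strides, so with `ratios = [8, 5, 4, 2]` the encoder compresses every 320 input samples into one latent frame and the decoder expands by the same factor. The snippet uses only numpy and the default values shown in this file.

# Worked example of the default hop length (illustrative only).
import numpy as np

ratios = [8, 5, 4, 2]
hop_length = int(np.prod(ratios))  # 8 * 5 * 4 * 2 = 320 samples per latent frame
assert hop_length == 320

samples = 32_000
frames = samples // hop_length  # 100 latent frames of size `dimension` for 32_000 samples
print(hop_length, frames)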
@@ -0,0 +1,45 @@
+ from abc import ABC, abstractmethod
+
+ import torch
+ from torch import nn
+
+
+ def init_states(
+     model: nn.Module, batch_size: int, sequence_length: int
+ ) -> dict[str, dict[str, torch.Tensor]]:
+     result = {}
+     for module_name, module in model.named_modules():
+         if not isinstance(module, StatefulModule):
+             continue
+         module._module_absolute_name = module_name
+         module_state = module.init_state(batch_size, sequence_length=sequence_length)
+         result[module_name] = module_state
+     return result
+
+
+ def increment_steps(
+     module: nn.Module, model_state: dict[str, dict[str, torch.Tensor]], increment: int = 1
+ ):
+     # print("incrementing steps by", increment)
+     for module_name, module in module.named_modules():
+         if not isinstance(module, StatefulModule):
+             continue
+         module.increment_step(model_state[module_name], increment)
+
+
+ class StatefulModule(ABC, nn.Module):
+     def __init__(self, *args, **kwds):
+         self._module_absolute_name = None
+         return super().__init__(*args, **kwds)
+
+     @abstractmethod
+     def init_state(self, batch_size: int, sequence_length: int):
+         """Initialize the state."""
+         raise NotImplementedError
+
+     def increment_step(self, state: dict, increment: int = 1):
+         pass
+
+     def get_state(self, model_state: dict[str, dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+         """Get the state for this module from the model state."""
+         return model_state[self._module_absolute_name]
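A minimal sketch of the state protocol above: `init_states` walks `named_modules()`, records each stateful module's qualified name, and collects a per-module state dict keyed by that name; `get_state` later uses the recorded name to look its own state back up. `StepCounter` below is a made-up module for illustration; only the functions and base class from this file (at the path imported elsewhere in the package) are assumed.

# Toy round trip through init_states / forward / increment_steps (illustrative only).
import torch
from torch import nn

from pocket_tts.modules.stateful_module import StatefulModule, increment_steps, init_states


class StepCounter(StatefulModule):  # hypothetical module, not part of pocket-tts
    def init_state(self, batch_size: int, sequence_length: int):
        return {"steps": torch.zeros(batch_size, sequence_length)}

    def forward(self, x, model_state):
        state = self.get_state(model_state)  # looked up via the recorded module name
        return x + state["steps"].shape[1]


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.counter = StepCounter()

    def forward(self, x, model_state):
        return self.counter(x, model_state)


model = ToyModel()
model_state = init_states(model, batch_size=2, sequence_length=10)
assert "counter" in model_state          # keys are fully qualified module names
y = model(torch.zeros(2, 3), model_state)
increment_steps(model, model_state)      # no-op here: StepCounter keeps the default increment_step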
@@ -0,0 +1,124 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from pocket_tts.modules.rope import RotaryEmbedding
+ from pocket_tts.modules.stateful_module import StatefulModule
+
+
+ def complete_kv(
+     cache: torch.Tensor, current_end: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     current_end = current_end.shape[0]
+
+     cache[0, :, current_end : current_end + k.shape[1]] = k
+     cache[1, :, current_end : current_end + v.shape[1]] = v
+     valid = cache[:, :, : current_end + k.shape[1]]
+     return valid[0], valid[1]
+
+
+ def _materialize_causal_mask(
+     shape: tuple[int, ...], shift: int, device: str | torch.device = "cpu"
+ ) -> torch.Tensor:
+     dtype = torch.float32
+
+     num_queries, num_keys = shape[-2:]
+     shift = num_keys - num_queries
+
+     tensor = torch.full(shape, dtype=dtype, fill_value=1, device=device)
+     mask = torch.tril(tensor, diagonal=shift).to(dtype)
+     mask = torch.log(mask)
+     return mask.to(dtype)
+
+
+ class StreamingMultiheadAttention(StatefulModule):
+     """Similar to `nn.MultiheadAttention` but with support for streaming.
+
+     Args:
+         embed_dim (int): Dimension to project to.
+         num_heads (int): Number of heads.
+         rope (`RotaryEmbedding`): Rope embedding to use.
+     """
+
+     def __init__(self, embed_dim: int, num_heads: int, rope: RotaryEmbedding):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.rope = rope
+         self.num_heads = num_heads
+
+         out_dim = embed_dim
+         num_kv = num_heads
+         kv_dim = (embed_dim // num_heads) * num_kv
+         out_dim += 2 * kv_dim
+         mult = 1
+         self.in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False)
+         self.out_proj = nn.Linear(embed_dim, mult * embed_dim, bias=False)
+
+     def _get_mask(self, shape: tuple[int, int], shift: int, device: torch.device) -> torch.Tensor:
+         return _materialize_causal_mask(shape, shift=shift, device=device)
+
+     def init_state(self, batch_size: int, sequence_length: int) -> dict[str, torch.Tensor]:
+         dim_per_head = self.embed_dim // self.num_heads
+         initial_current_end = torch.zeros((0,)).to(self.in_proj.weight.device)
+         return dict(
+             current_end=initial_current_end,
+             cache=torch.full(
+                 (2, batch_size, sequence_length, self.num_heads, dim_per_head),
+                 float("NaN"),
+                 device=self.in_proj.weight.device,
+                 dtype=self.in_proj.weight.dtype,
+             ),
+         )
+
+     def increment_step(self, state: dict, increment: int = 1):
+         new_size = state["current_end"].shape[0] + increment
+         state["current_end"] = torch.zeros((new_size,)).to(state["current_end"].device)
+
+     def _complete_kv(self, k, v, state: dict | None):
+         k, v = complete_kv(state["cache"], state["current_end"], k, v)
+         return k, v
+
+     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor, state: dict | None):
+         # Apply rope embeddings to query and key tensors.
+         streaming_offset = self._streaming_offset(state)
+         return self.rope(query, key, offset=streaming_offset)
+
+     def _streaming_offset(self, state: dict | None) -> torch.Tensor | int:
+         return state["current_end"].shape[0]
+
+     def check_model_state(self, model_state: dict):
+         if model_state is None:
+             raise ValueError("model_state must be provided")
+         return self.get_state(model_state)
+
+     def forward(self, query: torch.Tensor, model_state: dict | None):
+         state = self.check_model_state(model_state)
+
+         projected = self.in_proj(query)
+         # Reshape from (b, t, p*h*d) to (b, t, p, h, d) where p=3, h=num_heads
+         b, t, _ = projected.shape
+         d = self.embed_dim // self.num_heads
+         packed = projected.view(b, t, 3, self.num_heads, d)
+         q, k, v = torch.unbind(packed, dim=2)
+         q, k = self._apply_rope(q, k, state)
+         k, v = self._complete_kv(k, v, state)
+
+         mask_shape = (query.shape[1], query.shape[1] + state["current_end"].shape[0])
+         shift = state["current_end"].shape[0]
+
+         attn_mask = self._get_mask(mask_shape, shift=shift, device=q.device)
+
+         q, k, v = [x.transpose(1, 2) for x in (q, k, v)]
+         x = F.scaled_dot_product_attention(q, k, v, attn_mask)
+         x = x.transpose(1, 2)
+         # Reshape from (b, t, h, d) to (b, t, h*d)
+         b, t, h, d = x.shape
+         x = x.reshape(b, t, h * d)
+         x = self.out_proj(x)
+
+         return x
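A self-contained illustration of the additive mask built by `_materialize_causal_mask`: taking the log of a lower-triangular matrix of ones yields 0.0 where a query may attend and -inf where it may not, and the diagonal offset of num_keys - num_queries lets each new query see all previously cached keys plus itself. This sketch reproduces the construction with plain torch; it does not import pocket_tts.

# Illustration of the log(tril(ones)) causal mask (illustrative only).
import torch

num_queries, num_keys = 2, 5          # e.g. 2 new tokens attending over 3 cached steps + themselves
shift = num_keys - num_queries        # same recomputation as inside _materialize_causal_mask
ones = torch.full((num_queries, num_keys), 1.0)
mask = torch.log(torch.tril(ones, diagonal=shift))

# Row 0 (first new query) sees keys 0..3, row 1 sees keys 0..4.
print(mask)
# tensor([[0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])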