lt-tensor 0.0.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -0
- lt_tensor/_basics.py +270 -0
- lt_tensor/_torch_commons.py +12 -0
- lt_tensor/lr_schedulers.py +114 -0
- lt_tensor/math_ops.py +71 -0
- lt_tensor/misc_utils.py +628 -0
- lt_tensor/model_zoo/__init__.py +9 -0
- lt_tensor/model_zoo/bsc.py +210 -0
- lt_tensor/model_zoo/dfs.py +181 -0
- lt_tensor/model_zoo/fsn.py +67 -0
- lt_tensor/model_zoo/pos.py +121 -0
- lt_tensor/model_zoo/rsd.py +158 -0
- lt_tensor/model_zoo/tfr.py +140 -0
- lt_tensor/monotonic_align.py +70 -0
- lt_tensor/transform.py +349 -0
- lt_tensor-0.0.1a0.dist-info/METADATA +31 -0
- lt_tensor-0.0.1a0.dist-info/RECORD +20 -0
- lt_tensor-0.0.1a0.dist-info/WHEEL +5 -0
- lt_tensor-0.0.1a0.dist-info/licenses/LICENSE +201 -0
- lt_tensor-0.0.1a0.dist-info/top_level.txt +1 -0
lt_tensor/model_zoo/bsc.py
@@ -0,0 +1,210 @@
__all__ = [
    "FeedForward",
    "MLP",
    "TimestepEmbedder",
    "GRUEncoder",
    "ConvBlock1D",
    "TemporalPredictor",
    "StyleEncoder",
    "PatchEmbed1D",
    "MultiScaleEncoder1D",
]

from .._torch_commons import *
from .._basics import Model
from ..transform import get_sinusoidal_embedding

class FeedForward(Model):
    def __init__(
        self,
        d_model: int,
        ff_dim: int,
        dropout: float = 0.01,
        activation: nn.Module = nn.LeakyReLU(0.1),
        normalizer: nn.Module = nn.Identity(),
    ):
        """Creates a Feed-Forward Layer, with the chosen activation function and the normalizer."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
            normalizer,
        )

    def forward(self, x: Tensor):
        return self.net(x)

class MLP(Model):
    def __init__(
        self,
        d_model: int,
        ff_dim: int,
        n_classes: int,
        dropout: float = 0.01,
        activation: nn.Module = nn.LeakyReLU(0.1),
    ):
        """Creates an MLP block with the chosen activation function."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            activation,
            nn.Dropout(dropout),
            nn.Linear(ff_dim, n_classes),
        )

    def forward(self, x: Tensor):
        return self.net(x)

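A minimal usage sketch for the two blocks above, assuming the wheel is installed and that Model (defined in lt_tensor/_basics.py, not shown in this hunk) behaves like a plain nn.Module:

import torch
from lt_tensor.model_zoo.bsc import FeedForward, MLP

x = torch.randn(8, 32, 256)                 # [B, T, d_model]
ff = FeedForward(d_model=256, ff_dim=1024)
print(ff(x).shape)                          # torch.Size([8, 32, 256])

clf = MLP(d_model=256, ff_dim=512, n_classes=10)
print(clf(x[:, 0]).shape)                   # torch.Size([8, 10])
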
class TimestepEmbedder(Model):
    def __init__(self, dim_emb: int, proj_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim_emb, proj_dim),
            nn.SiLU(),
            nn.Linear(proj_dim, proj_dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: [B] (long)
        emb = get_sinusoidal_embedding(t, self.net[0].in_features)  # [B, dim_emb]
        return self.net(emb)  # [B, proj_dim]

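TimestepEmbedder recovers dim_emb from its first Linear layer, so only the timestep tensor is needed at call time. A sketch, assuming get_sinusoidal_embedding(t, dim) returns a [B, dim] embedding as the call above implies:

import torch
from lt_tensor.model_zoo.bsc import TimestepEmbedder

emb = TimestepEmbedder(dim_emb=128, proj_dim=256)
t = torch.randint(0, 1000, (8,))    # one diffusion timestep per batch element
print(emb(t).shape)                 # torch.Size([8, 256])
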
class GRUEncoder(Model):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int = 1,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.gru = nn.GRU(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.output_dim = hidden_dim * (2 if bidirectional else 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, input_dim]
        output, _ = self.gru(x)  # output: [B, T, hidden_dim*D]
        return output

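The output_dim attribute accounts for the direction count, which is handy when wiring the encoder into later layers. A shape sketch:

import torch
from lt_tensor.model_zoo.bsc import GRUEncoder

enc = GRUEncoder(input_dim=80, hidden_dim=128, bidirectional=True)
x = torch.randn(4, 100, 80)          # [B, T, input_dim]
print(enc(x).shape, enc.output_dim)  # torch.Size([4, 100, 256]) 256
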
class ConvBlock1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        norm: bool = True,
        residual: bool = False,
    ):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.BatchNorm1d(out_channels) if norm else nn.Identity()
        self.act = nn.LeakyReLU(0.1)
        self.residual = residual and in_channels == out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.act(self.norm(self.conv(x)))
        return x + y if self.residual else y

class TemporalPredictor(Model):
    def __init__(
        self,
        d_model: int,
        hidden_dim: int = 128,
        n_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(nn.Conv1d(d_model, hidden_dim, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            # GroupNorm(1, C) normalizes over the channel axis of the [B, C, T]
            # conv output; a plain nn.LayerNorm(hidden_dim) here would be applied
            # to the time axis and fail whenever T != hidden_dim.
            layers.append(nn.GroupNorm(1, hidden_dim))
            layers.append(nn.Dropout(dropout))
            d_model = hidden_dim
        self.network = nn.Sequential(*layers)
        self.proj = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, D]
        x = x.transpose(1, 2)  # [B, D, T]
        x = self.network(x)  # [B, H, T]
        x = x.transpose(1, 2)  # [B, T, H]
        return self.proj(x).squeeze(-1)  # [B, T]

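Note the differing layout conventions: ConvBlock1D is channel-first, while TemporalPredictor takes [B, T, D] and transposes internally. A shape sketch:

import torch
from lt_tensor.model_zoo.bsc import ConvBlock1D, TemporalPredictor

conv = ConvBlock1D(64, 64, residual=True)  # residual path only activates when in == out channels
x = torch.randn(2, 64, 200)                # [B, C, T]
print(conv(x).shape)                       # torch.Size([2, 64, 200])

pred = TemporalPredictor(d_model=256)      # e.g. per-frame duration or energy prediction
h = torch.randn(2, 120, 256)               # [B, T, D]
print(pred(h).shape)                       # torch.Size([2, 120])
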
class StyleEncoder(Model):
    def __init__(self, in_channels: int = 80, hidden: int = 128, out_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, hidden, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.linear = nn.Linear(hidden, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, Mels, T]
        x = self.net(x).squeeze(-1)  # [B, hidden]
        return self.linear(x)  # [B, out_dim]

class PatchEmbed1D(Model):
    def __init__(self, in_channels: int, patch_size: int, embed_dim: int):
        """
        Args:
            in_channels: number of input channels (e.g., mel bins)
            patch_size: number of time-steps per patch
            embed_dim: dimension of the patch embedding
        """
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv1d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]
        x = self.proj(x)  # [B, embed_dim, T//patch_size]
        return x.transpose(1, 2)  # [B, T_patches, embed_dim]

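Both modules consume channel-first, spectrogram-like input. Because the patch projection uses stride == kernel_size, any trailing frames beyond a multiple of patch_size are dropped. A sketch:

import torch
from lt_tensor.model_zoo.bsc import StyleEncoder, PatchEmbed1D

mel = torch.randn(2, 80, 400)                      # [B, mel_bins, T]
style = StyleEncoder(in_channels=80, out_dim=256)
print(style(mel).shape)                            # torch.Size([2, 256])

patch = PatchEmbed1D(in_channels=80, patch_size=4, embed_dim=192)
print(patch(mel).shape)                            # torch.Size([2, 100, 192])
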
class MultiScaleEncoder1D(Model):
    def __init__(
        self, in_channels: int, hidden: int, num_layers: int = 4, kernel_size: int = 3
    ):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers.append(
                nn.Conv1d(
                    in_channels if i == 0 else hidden,
                    hidden,
                    kernel_size=kernel_size,
                    dilation=2**i,
                    padding=(kernel_size - 1) * (2**i) // 2,
                )
            )
            layers.append(nn.GELU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]
        return self.net(x)  # [B, hidden, T]

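With the dilation doubling per layer and padding set to (kernel_size - 1) * 2**i // 2, the sequence length is preserved for odd kernel sizes while the receptive field grows exponentially. A sketch:

import torch
from lt_tensor.model_zoo.bsc import MultiScaleEncoder1D

enc = MultiScaleEncoder1D(in_channels=80, hidden=256, num_layers=4)
x = torch.randn(2, 80, 300)
print(enc(x).shape)                 # torch.Size([2, 256, 300])
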
lt_tensor/model_zoo/dfs.py
@@ -0,0 +1,181 @@
__all__ = [
    "Downsample1D",
    "Upsample1D",
    "DiffusionUNet",
    "UNetConvBlock1D",
    "UNetUpBlock1D",
    "NoisePredictor1D",
]

from .._torch_commons import *
from .._basics import Model
from .rsd import ResBlock1D
from ..misc_utils import log_tensor

class Downsample1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.pool = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)


class Upsample1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation=nn.ReLU(inplace=True),
    ):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm1d(out_channels),
            activation,
        )

    def forward(self, x):
        return self.up(x)

class DiffusionUNet(Model):
    def __init__(self, in_channels=1, base_channels=64, out_channels=1, depth=4):
        super().__init__()

        self.depth = depth
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        # Keep track of channel sizes per layer for skip connections
        self.channels = [in_channels]  # starting input channel
        for i in range(depth):
            enc_in = self.channels[-1]
            enc_out = base_channels * (2**i)
            # Encoder block and downsample
            self.encoder_blocks.append(ResBlock1D(enc_in, enc_out))
            self.downsamples.append(
                Downsample1D(enc_out, enc_out)
            )  # halve time, keep channels
            self.channels.append(enc_out)
        # Bottleneck
        bottleneck_ch = self.channels[-1]
        self.bottleneck = ResBlock1D(bottleneck_ch, bottleneck_ch)
        # Decoder blocks (reverse channel flow)
        for i in reversed(range(depth)):
            skip_ch = self.channels[i + 1]  # from encoder
            dec_out = self.channels[i]  # match earlier stage's output
            self.upsamples.append(Upsample1D(skip_ch, skip_ch))
            self.decoder_blocks.append(ResBlock1D(skip_ch * 2, dec_out))
        # Final output projection (out_channels)
        self.final = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: Tensor):
        skips = []

        # Encoder
        for enc, down in zip(self.encoder_blocks, self.downsamples):
            # log_tensor(x, "before enc")
            x = enc(x)
            skips.append(x)
            x = down(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder
        for up, dec, skip in zip(self.upsamples, self.decoder_blocks, reversed(skips)):
            x = up(x)

            # Match lengths via trimming or padding
            if x.shape[-1] > skip.shape[-1]:
                x = x[..., : skip.shape[-1]]
            elif x.shape[-1] < skip.shape[-1]:
                diff = skip.shape[-1] - x.shape[-1]
                x = F.pad(x, (0, diff))

            x = torch.cat([x, skip], dim=1)  # concat on channels
            x = dec(x)

        # Final 1x1 conv
        return self.final(x)

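A forward-pass sketch; it assumes ResBlock1D (from rsd.py, not shown in this hunk) preserves sequence length. Inputs whose length is divisible by 2**depth make the trim/pad alignment a no-op:

import torch
from lt_tensor.model_zoo.dfs import DiffusionUNet

unet = DiffusionUNet(in_channels=1, base_channels=64, out_channels=1, depth=4)
x = torch.randn(2, 1, 256)          # 256 = 16 * 2**4
print(unet(x).shape)                # torch.Size([2, 1, 256])
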
class UNetConvBlock1D(Model):
    def __init__(self, in_channels: int, out_channels: int, down: bool = True):
        super().__init__()
        self.down = down
        self.conv = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2 if down else 1,
                padding=1,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
        )
        # The shortcut must also change resolution when down=True; an Identity
        # shortcut would produce a length mismatch with the strided conv path
        # even when the channel counts agree.
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=2 if down else 1)
            if in_channels != out_channels or down
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]
        residual = self.downsample(x)
        return self.conv(x) + residual

class UNetUpBlock1D(Model):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
        )
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)  # skip connection
        return self.conv(x)

class NoisePredictor1D(Model):
    def __init__(self, in_channels: int, cond_dim: int = 0, hidden: int = 128):
        """
        Args:
            in_channels: channels of the noisy input [B, C, T]
            cond_dim: optional condition vector [B, cond_dim]
        """
        super().__init__()
        # Project the condition to in_channels (not hidden) so it can be added
        # to the input before the convolutional stack.
        self.proj = nn.Linear(cond_dim, in_channels) if cond_dim > 0 else None
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, hidden, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv1d(hidden, in_channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        # x: [B, C, T], cond: [B, cond_dim]
        if cond is not None:
            cond_proj = self.proj(cond).unsqueeze(-1)  # [B, C, 1]
            x = x + cond_proj  # simple conditioning
        return self.net(x)  # [B, C, T]
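A conditioning sketch: the condition vector is projected to the input channel count and broadcast over time before the convolutional stack:

import torch
from lt_tensor.model_zoo.dfs import NoisePredictor1D

net = NoisePredictor1D(in_channels=80, cond_dim=256)
x = torch.randn(4, 80, 200)         # noisy input, e.g. a mel spectrogram
cond = torch.randn(4, 256)          # e.g. a style or speaker vector
print(net(x, cond).shape)           # torch.Size([4, 80, 200])
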
lt_tensor/model_zoo/fsn.py
@@ -0,0 +1,67 @@
__all__ = [
    "ConcatFusion",
    "FiLMFusion",
    "BilinearFusion",
    "CrossAttentionFusion",
    "GatedFusion",
]

from .._torch_commons import *
from .._basics import Model

class ConcatFusion(Model):
    def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim_a + in_dim_b, out_dim)

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        x = torch.cat([a, b], dim=-1)
        return self.proj(x)


class FiLMFusion(Model):
    def __init__(self, cond_dim: int, feature_dim: int):
        super().__init__()
        self.modulator = nn.Linear(cond_dim, 2 * feature_dim)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        scale, shift = self.modulator(cond).chunk(2, dim=-1)
        return x * scale + shift


class BilinearFusion(Model):
    def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
        super().__init__()
        self.bilinear = nn.Bilinear(in_dim_a, in_dim_b, out_dim)

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        return self.bilinear(a, b)

class CrossAttentionFusion(nn.Module):
    def __init__(self, q_dim: int, kv_dim: int, n_heads: int = 4, d_model: int = 256):
        super().__init__()
        self.q_proj = nn.Linear(q_dim, d_model)
        self.k_proj = nn.Linear(kv_dim, d_model)
        self.v_proj = nn.Linear(kv_dim, d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, batch_first=True
        )

    def forward(
        self, query: Tensor, context: Tensor, mask: Optional[Tensor] = None
    ) -> Tensor:
        Q = self.q_proj(query)
        K = self.k_proj(context)
        V = self.v_proj(context)
        output, _ = self.attn(Q, K, V, key_padding_mask=mask)
        return output

class GatedFusion(nn.Module):
    def __init__(self, in_dim: int):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.Sigmoid())

    def forward(self, a: Tensor, b: Tensor) -> Tensor:
        gate = self.gate(torch.cat([a, b], dim=-1))
        return gate * a + (1 - gate) * b
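A shape sketch for three of the fusion modules; note that FiLMFusion modulates x with a condition of a possibly different width, while GatedFusion requires both inputs to share in_dim:

import torch
from lt_tensor.model_zoo.fsn import ConcatFusion, FiLMFusion, GatedFusion

a = torch.randn(2, 50, 256)         # e.g. text features
b = torch.randn(2, 50, 128)         # e.g. prosody features

print(ConcatFusion(256, 128, 256)(a, b).shape)                # torch.Size([2, 50, 256])
print(FiLMFusion(cond_dim=128, feature_dim=256)(a, b).shape)  # torch.Size([2, 50, 256])
print(GatedFusion(256)(a, a.flip(1)).shape)                   # torch.Size([2, 50, 256])
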
lt_tensor/model_zoo/pos.py
@@ -0,0 +1,121 @@
__all__ = [
    "RotaryEmbedding",
    "PositionalEncoding",
    "LearnedPositionalEncoding",
]

import math
from .._torch_commons import *
from .._basics import Model

class RotaryEmbedding(Module):
    def __init__(self, dim: int, base: int = 10000):
        """
        Rotary Positional Embedding Module.
        Args:
            dim (int): The dimension of the rotary embedding (must be even).
            base (int): The base frequency scale (default: 10000).
        """
        super().__init__()
        assert dim % 2 == 0, "Rotary dimension must be even"
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len=None):
        """
        Apply rotary embeddings to input tensor.
        Args:
            x (torch.Tensor): Input tensor of shape [batch, seq_len, dim].
            seq_len (int, optional): Override for sequence length.
        Returns:
            torch.Tensor: Tensor with rotary embeddings applied.
        """
        bsz, seq_len = x.shape[0], seq_len or x.shape[1]
        device = x.device

        pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        freqs = torch.pow(
            self.base, -torch.arange(0, self.dim, 2, device=device).float() / self.dim
        )
        angle = pos * freqs  # [seq_len, dim/2]

        sin = torch.sin(angle)
        cos = torch.cos(angle)

        # Expand and interleave to [seq_len, dim]
        sin = torch.stack((sin, sin), dim=-1).reshape(seq_len, self.dim)
        cos = torch.stack((cos, cos), dim=-1).reshape(seq_len, self.dim)

        sin = sin.unsqueeze(0).expand(bsz, -1, -1)  # [batch, seq_len, dim]
        cos = cos.unsqueeze(0).expand(bsz, -1, -1)

        return self.apply_rotary(x, sin, cos)

    def _apply_rotary(self, x, sin, cos):
        """Half-split variant; it may still be useful, but for now it is the source of problems for the text model."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def apply_rotary(self, x, sin, cos):
        """x: [batch, seq_len, dim] → assumes dim is even"""
        b, s, d = x.shape
        x = x.view(b, s, d // 2, 2)  # [b, s, d//2, 2]
        sin = sin.view(b, s, d // 2, 2)
        cos = cos.view(b, s, d // 2, 2)

        # Apply rotation: even, odd = x[..., 0], x[..., 1]
        x_rotated = torch.stack(
            [
                x[..., 0] * cos[..., 0] - x[..., 1] * sin[..., 0],
                x[..., 0] * sin[..., 0] + x[..., 1] * cos[..., 0],
            ],
            dim=-1,
        )

        return x_rotated.view(b, s, d)  # Back to [b, s, d]

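Since apply_rotary only rotates each (even, odd) channel pair, the per-token norm is unchanged, which makes for a quick sanity check:

import torch
from lt_tensor.model_zoo.pos import RotaryEmbedding

rope = RotaryEmbedding(dim=64)
q = torch.randn(2, 10, 64)          # [batch, seq_len, dim], dim must be even
out = rope(q)
print(out.shape)                    # torch.Size([2, 10, 64])
print(torch.allclose(q.norm(dim=-1), out.norm(dim=-1), atol=1e-5))  # True
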
class PositionalEncoding(Module):
    def __init__(self, d_model: int, max_len: int = 8192):
        super().__init__()
        # Create a [max_len, d_model] matrix holding the positional encoding for each position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
            1
        )  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe, persistent=False)  # Shape: (1, max_len, d_model)

    def forward(self, x: Tensor, seq_len: Optional[int] = None):
        # x shape: (batch_size, seq_len, d_model)
        s_sz = seq_len or x.size(1)
        x = x + self.pe[:, :s_sz]
        return x

class LearnedPositionalEncoding(Module):
    def __init__(self, max_len: int, dim_model: int, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(max_len, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len

    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        # x: [B, T, D] or [T, D]
        seq_len = x.size(1 if x.dim() == 3 else 0)
        if seq_len + offset > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len + offset} exceeds max length {self.max_len}"
            )
        positions = torch.arange(offset, offset + seq_len, device=x.device)
        pos_embed = self.embedding(positions)
        if x.dim() == 3:
            pos_embed = pos_embed.unsqueeze(0).expand(x.size(0), -1, -1)  # [B, T, D]
        return self.dropout(x + pos_embed)
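A sketch contrasting the two encodings; the learned variant bounds positions by max_len and supports an offset for continuing a sequence:

import torch
from lt_tensor.model_zoo.pos import PositionalEncoding, LearnedPositionalEncoding

x = torch.randn(2, 100, 512)        # [B, T, d_model]
print(PositionalEncoding(d_model=512)(x).shape)   # torch.Size([2, 100, 512])

lpe = LearnedPositionalEncoding(max_len=1024, dim_model=512)
print(lpe(x, offset=10).shape)      # torch.Size([2, 100, 512]); positions 10..109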