lt-tensor 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +0 -0
- lt_tensor/_basics.py +244 -0
- lt_tensor/_torch_commons.py +12 -0
- lt_tensor/lr_schedulers.py +108 -0
- lt_tensor/math_ops.py +120 -0
- lt_tensor/misc_utils.py +596 -0
- lt_tensor/model_zoo/__init__.py +2 -0
- lt_tensor/model_zoo/basic.py +65 -0
- lt_tensor/model_zoo/diffusion/__init__.py +6 -0
- lt_tensor/model_zoo/diffusion/models.py +114 -0
- lt_tensor/model_zoo/residual.py +236 -0
- lt_tensor/model_zoo/transformer_models/__init__.py +6 -0
- lt_tensor/model_zoo/transformer_models/models.py +132 -0
- lt_tensor/model_zoo/transformer_models/positional_encoders.py +95 -0
- lt_tensor/monotonic_align.py +70 -0
- lt_tensor/transform.py +113 -0
- lt_tensor-0.0.1.dev0.dist-info/METADATA +33 -0
- lt_tensor-0.0.1.dev0.dist-info/RECORD +21 -0
- lt_tensor-0.0.1.dev0.dist-info/WHEEL +5 -0
- lt_tensor-0.0.1.dev0.dist-info/licenses/LICENSE +201 -0
- lt_tensor-0.0.1.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,114 @@ lt_tensor/model_zoo/diffusion/models.py
```python
__all__ = [
    "ResidualBlock1D_B",
    "Downsample1D",
    "Upsample1D",
    "DiffusionUNet",
    "DiffusionUNetT",
]

from ..._torch_commons import *
from ..._basics import Model
from ..residual import ResBlock1D, ResBlock1DT
from ...misc_utils import log_tensor


class Downsample1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.pool = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)


class Upsample1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation=nn.ReLU(inplace=True),
    ):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm1d(out_channels),
            activation,
        )

    def forward(self, x):
        return self.up(x)


class DiffusionUNet(Model):
    def __init__(self, in_channels=1, base_channels=64, out_channels=1, depth=4):
        super().__init__()

        self.depth = depth
        self.encoder_blocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        # Keep track of channel sizes per layer for skip connections
        self.channels = [in_channels]  # starting input channel
        for i in range(depth):
            enc_in = self.channels[-1]
            enc_out = base_channels * (2**i)
            # Encoder block and downsample
            self.encoder_blocks.append(ResBlock1D(enc_in, enc_out))
            self.downsamples.append(
                Downsample1D(enc_out, enc_out)
            )  # halve time, keep channels
            self.channels.append(enc_out)
        # Bottleneck
        bottleneck_ch = self.channels[-1]
        self.bottleneck = ResBlock1D(bottleneck_ch, bottleneck_ch)
        # Decoder blocks (reverse channel flow)
        for i in reversed(range(depth)):
            skip_ch = self.channels[i + 1]  # from encoder
            dec_out = self.channels[i]  # match earlier stage's output
            self.upsamples.append(Upsample1D(skip_ch, skip_ch))
            self.decoder_blocks.append(ResBlock1D(skip_ch * 2, dec_out))
        # Final output projection (out_channels)
        self.final = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: Tensor):
        skips = []

        # Encoder
        for enc, down in zip(self.encoder_blocks, self.downsamples):
            # log_tensor(x, "before enc")
            x = enc(x)
            skips.append(x)
            x = down(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder
        for up, dec, skip in zip(self.upsamples, self.decoder_blocks, reversed(skips)):
            x = up(x)

            # Match lengths via trimming or padding
            if x.shape[-1] > skip.shape[-1]:
                x = x[..., : skip.shape[-1]]
            elif x.shape[-1] < skip.shape[-1]:
                diff = skip.shape[-1] - x.shape[-1]
                x = F.pad(x, (0, diff))

            x = torch.cat([x, skip], dim=1)  # concat on channels
            x = dec(x)

        # Final 1x1 conv
        return self.final(x)
```
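For reference, a minimal usage sketch of the `DiffusionUNet` defined above, driving it with a random waveform-like tensor. The import path is inferred from the wheel layout (`lt_tensor/model_zoo/diffusion/models.py`), and the `[batch, channels, time]` shape convention follows the 1D convolutions in the class; treat this as an illustration, not documented API.

```python
import torch

# Inferred import path; the package may also re-export these names from
# lt_tensor.model_zoo.diffusion.__init__.
from lt_tensor.model_zoo.diffusion.models import DiffusionUNet

# depth=4 halves the time axis four times, so an input length divisible by
# 2**4 avoids relying on the trim/pad logic inside forward().
model = DiffusionUNet(in_channels=1, base_channels=64, out_channels=1, depth=4)

x = torch.randn(2, 1, 256)  # [batch, channels, time]
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([2, 1, 256])
```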
@@ -0,0 +1,236 @@ lt_tensor/model_zoo/residual.py
```python
from .._torch_commons import *
from .._basics import Model
import math
from ..misc_utils import log_tensor


def initialize__weights(model: nn.Module, method: str = "xavier"):
    """Initialize model weights using the specified method."""
    for name, param in model.named_parameters():
        if "weight" in name:
            if method == "xavier":
                nn.init.xavier_uniform_(param)
            elif method == "kaiming":
                nn.init.kaiming_uniform_(param, nonlinearity="relu")
        elif "bias" in name:
            nn.init.constant_(param, 0)


def spectral_norm_select(module: Module, enabled: bool):
    if enabled:
        return spectral_norm(module)
    return module


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if "Conv" in classname:
        m.weight.data.normal_(mean, std)


class ResBlock1D(Model):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: Union[Sequence[int], int] = (1, 3, 5),
        activation: nn.Module = nn.LeakyReLU(0.1),
        num_groups: int = 1,
        batched: bool = True,
    ):
        super().__init__()
        self.conv = nn.ModuleList()
        if isinstance(dilation, int):
            dilation = [dilation]

        if batched:
            layernorm_fn = lambda x: nn.GroupNorm(num_groups=num_groups, num_channels=x)
        else:
            layernorm_fn = lambda x: nn.LayerNorm(normalized_shape=x)
        for i, dil in enumerate(dilation):
            self.conv.append(
                nn.ModuleDict(
                    dict(
                        net=nn.Sequential(
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, dil
                            ),
                            activation,
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, 1, True
                            ),
                            activation,
                        ),
                        l_norm=layernorm_fn(in_channels),
                    )
                )
            )
        self.final = nn.Sequential(
            self._get_conv_layer(in_channels, out_channels, kernel_size, 1, True),
            activation,
        )
        self.conv.apply(init_weights)

    def _get_conv_layer(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        dilation: int,
        pad_gate: bool = False,
    ):
        return weight_norm(
            nn.Conv1d(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=(
                    int((kernel_size * dilation - dilation) / 2)
                    if not pad_gate
                    else int((kernel_size * 1 - 1) / 2)
                ),
            )
        )

    def forward(self, x: Tensor):
        for i, layer in enumerate(self.conv):
            xt = layer["net"](x)
            x = xt + x
            x = layer["l_norm"](x)
        return self.final(x)

    def remove_weight_norm(self):
        for module in self.modules():
            try:
                remove_weight_norm(module)
            except ValueError:
                pass  # Not normed, skip


class ResBlock2D(Model):
    def __init__(
        self,
        in_channels,
        out_channels,
        downsample=False,
        spec_norm: bool = False,
    ):
        super().__init__()
        stride = 2 if downsample else 1

        self.block = nn.Sequential(
            spectral_norm_select(
                nn.Conv2d(in_channels, out_channels, 3, stride, 1), spec_norm
            ),
            nn.LeakyReLU(0.2),
            spectral_norm_select(
                nn.Conv2d(out_channels, out_channels, 3, 1, 1), spec_norm
            ),
        )

        self.skip = nn.Identity()
        if downsample or in_channels != out_channels:
            self.skip = spectral_norm_select(
                nn.Conv2d(in_channels, out_channels, 1, stride), spec_norm
            )
        # one less operation to be handled every cycle
        self.sqrt_2 = math.sqrt(2)

    def forward(self, x):
        return (self.block(x) + self.skip(x)) / self.sqrt_2


class ResBlock1DT(Model):
    """For time-based residual layers"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: Union[Sequence[int], int] = (1, 3, 5),
        activation: nn.Module = nn.LeakyReLU(0.1),
        num_groups: int = 1,
        time_emb_dim: int = 1,
        batched: bool = True,
    ):
        super().__init__()
        self.conv = nn.ModuleList()
        if isinstance(dilation, int):
            dilation = [dilation]
        # NOTE: projects to out_channels, but forward() adds the result to features
        # that still carry in_channels; the addition only broadcasts when the two match.
        self.time_proj = nn.Linear(time_emb_dim, out_channels)
        if batched:
            layernorm_fn = lambda x: nn.GroupNorm(num_groups=num_groups, num_channels=x)
        else:
            layernorm_fn = lambda x: nn.LayerNorm(normalized_shape=x)
        for i, dil in enumerate(dilation):
            self.conv.append(
                nn.ModuleDict(
                    dict(
                        net=nn.Sequential(
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, dil
                            ),
                            activation,
                            self._get_conv_layer(
                                in_channels, in_channels, kernel_size, 1, True
                            ),
                            activation,
                        ),
                        l_norm=layernorm_fn(in_channels),
                    )
                )
            )
        self.final = nn.Sequential(
            self._get_conv_layer(in_channels, out_channels, kernel_size, 1, True),
            activation,
        )
        self.conv.apply(init_weights)

    def _get_conv_layer(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        dilation: int,
        pad_gate: bool = False,
    ):
        return weight_norm(
            nn.Conv1d(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=(
                    int((kernel_size * dilation - dilation) / 2)
                    if not pad_gate
                    else int((kernel_size * 1 - 1) / 2)
                ),
            )
        )

    def forward(self, x: Tensor, t_embed: Optional[Tensor] = None):
        if t_embed is not None:
            t_emb = self.time_proj(t_embed).unsqueeze(-1)  # [B, C, 1]

        for i, layer in enumerate(self.conv):
            if t_embed is not None:
                xt = layer["net"](x) + t_emb
            else:
                xt = layer["net"](x)
            x = xt + x
            x = layer["l_norm"](x)
        return self.final(x)

    def remove_weight_norm(self):
        for module in self.modules():
            try:
                remove_weight_norm(module)
            except ValueError:
                pass  # Not normed, skip
```
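A short, hedged sketch of how the `ResBlock1D` above can be exercised on its own, with the import path inferred from the wheel layout (`lt_tensor/model_zoo/residual.py`). The channel count stays at `in_channels` through the dilated stack and only the final projection changes it, which is why the input uses `in_channels` and the output comes back with `out_channels`.

```python
import torch

# Inferred import path based on lt_tensor/model_zoo/residual.py.
from lt_tensor.model_zoo.residual import ResBlock1D

block = ResBlock1D(in_channels=32, out_channels=64, kernel_size=3, dilation=(1, 3, 5))

x = torch.randn(4, 32, 200)  # [batch, in_channels, time]
with torch.no_grad():
    y = block(x)
print(y.shape)  # expected: torch.Size([4, 64, 200])

# Weight norm can be stripped for inference or export:
block.remove_weight_norm()
```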
@@ -0,0 +1,132 @@ lt_tensor/model_zoo/transformer_models/models.py
```python
import math
from ..._torch_commons import *
from ..._basics import Model
from lt_utils.misc_utils import default

from .positional_encoders import *
from ..basic import FeedForward


def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        nn.init.constant_(module.bias, 0)
        nn.init.constant_(module.weight, 1.0)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        attn_output, _ = self.self_attn(x, x, x, attn_mask=src_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, ff_size: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)

        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = FeedForward(d_model, ff_size, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,  # Decoder input [B, T, d_model]
        encoder_out: Tensor,  # Encoder output [B, S, d_model]
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # 1. Masked Self-Attention
        attn_output, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        # 2. Cross-Attention
        cross_output, _ = self.cross_attn(
            x, encoder_out, encoder_out, attn_mask=memory_mask
        )
        x = self.norm2(x + self.dropout(cross_output))

        # 3. FeedForward
        ff_output = self.ff(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class TransformerEncoder(Model):
    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 4,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(d_model, n_heads, ff_size, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        for layer in self.layers:
            x = layer(x, src_mask)
        return x


class TransformerDecoder(Model):
    def __init__(
        self,
        d_model: int = 64,
        n_heads: int = 2,
        ff_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(d_model, n_heads, ff_size, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: Tensor,
        encoder_out: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
    ) -> Tensor:
        for layer in self.layers:
            x = layer(x, encoder_out, tgt_mask, memory_mask)
        return x
```
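A minimal sketch pairing the `TransformerEncoder` and `TransformerDecoder` above, with a causal target mask built via `torch.triu`. The import path is inferred from the wheel layout, and the `batch_first=True` attention convention comes straight from the layer definitions; the mask shape follows `nn.MultiheadAttention`'s boolean `attn_mask` semantics (True blocks attention).

```python
import torch

# Inferred import path based on lt_tensor/model_zoo/transformer_models/models.py.
from lt_tensor.model_zoo.transformer_models.models import (
    TransformerEncoder,
    TransformerDecoder,
)

d_model = 64
encoder = TransformerEncoder(d_model=d_model, n_heads=4, ff_size=256, num_layers=2)
decoder = TransformerDecoder(d_model=d_model, n_heads=2, ff_size=256, num_layers=2)

src = torch.randn(2, 50, d_model)  # [B, S, d_model]
tgt = torch.randn(2, 30, d_model)  # [B, T, d_model]

# Causal mask: True above the diagonal forbids attending to future positions.
tgt_mask = torch.triu(torch.ones(30, 30, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    memory = encoder(src)
    out = decoder(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)  # expected: torch.Size([2, 30, 64])
```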
@@ -0,0 +1,95 @@ lt_tensor/model_zoo/transformer_models/positional_encoders.py
```python
__all__ = ["RotaryEmbedding", "PositionalEncoding"]
import math
from ..._torch_commons import *
from ..._basics import Model


class RotaryEmbedding(Module):
    def __init__(self, dim: int, base: int = 10000):
        """
        Rotary Positional Embedding Module.
        Args:
            dim (int): The dimension of the rotary embedding (must be even).
            base (int): The base frequency scale (default: 10000).
        """
        super().__init__()
        assert dim % 2 == 0, "Rotary dimension must be even"
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len=None):
        """
        Apply rotary embeddings to the input tensor.
        Args:
            x (torch.Tensor): Input tensor of shape [batch, seq_len, dim].
            seq_len (int, optional): Override for sequence length.
        Returns:
            torch.Tensor: Tensor with rotary embeddings applied.
        """
        bsz, seq_len = x.shape[0], seq_len or x.shape[1]
        device = x.device

        pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        freqs = torch.pow(
            self.base, -torch.arange(0, self.dim, 2, device=device).float() / self.dim
        )
        angle = pos * freqs  # [seq_len, dim/2]

        sin = torch.sin(angle)
        cos = torch.cos(angle)

        # Expand and interleave to [seq_len, dim]
        sin = torch.stack((sin, sin), dim=-1).reshape(seq_len, self.dim)
        cos = torch.stack((cos, cos), dim=-1).reshape(seq_len, self.dim)

        sin = sin.unsqueeze(0).expand(bsz, -1, -1)  # [batch, seq_len, dim]
        cos = cos.unsqueeze(0).expand(bsz, -1, -1)

        return self.apply_rotary(x, sin, cos)

    def _apply_rotary(self, x, sin, cos):
        """This version may still be useful, but for now it's a problem for the text model."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def apply_rotary(self, x, sin, cos):
        """x: [batch, seq_len, dim] → assumes dim is even"""
        b, s, d = x.shape
        x = x.view(b, s, d // 2, 2)  # [b, s, d//2, 2]
        sin = sin.view(b, s, d // 2, 2)
        cos = cos.view(b, s, d // 2, 2)

        # Apply rotation: even, odd = x[..., 0], x[..., 1]
        x_rotated = torch.stack(
            [
                x[..., 0] * cos[..., 0] - x[..., 1] * sin[..., 0],
                x[..., 0] * sin[..., 0] + x[..., 1] * cos[..., 0],
            ],
            dim=-1,
        )

        return x_rotated.view(b, s, d)  # Back to [b, s, d]


class PositionalEncoding(Module):
    def __init__(self, d_model: int, max_len: int = 8192):
        super().__init__()
        # Create a [max_len, d_model] matrix holding the positional encoding for each position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
            1
        )  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe, persistent=False)  # Shape: (1, max_len, d_model)

    def forward(self, x: Tensor, seq_len: Optional[int] = None):
        # x shape: (batch_size, seq_len, d_model)
        s_sz = seq_len or x.size(1)
        x = x + self.pe[:, :s_sz]
        return x
```
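Both encoders above operate on `[batch, seq_len, dim]` tensors. A small sketch (import path inferred from the wheel layout) showing `PositionalEncoding` added to a feature tensor before an encoder stack, and `RotaryEmbedding` applied directly to the same tensor.

```python
import torch

# Inferred import path based on
# lt_tensor/model_zoo/transformer_models/positional_encoders.py.
from lt_tensor.model_zoo.transformer_models.positional_encoders import (
    PositionalEncoding,
    RotaryEmbedding,
)

x = torch.randn(2, 100, 64)  # [batch, seq_len, dim]

# Sinusoidal table, sliced to the current sequence length and added to the input.
pos_enc = PositionalEncoding(d_model=64, max_len=8192)
x_pe = pos_enc(x)

# Rotary embedding rotates adjacent channel pairs; dim must be even.
rope = RotaryEmbedding(dim=64)
x_rope = rope(x)

print(x_pe.shape, x_rope.shape)  # both: torch.Size([2, 100, 64])
```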
@@ -0,0 +1,70 @@ lt_tensor/monotonic_align.py
````python
from numba import njit, prange


@njit()
def maximum_path_each(path, value, t_x, t_y, max_neg_val):
    index = t_x - 1
    # Forward pass: calculate max path sums
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            v_cur = max_neg_val if x == y else value[x, y - 1]
            v_prev = (
                0.0
                if (x == 0 and y == 0)
                else (max_neg_val if x == 0 else value[x - 1, y - 1])
            )
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    # Backtrack to store the path
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1


@njit()  # Took almost 10x the time while testing with "parallel=True".
def maximum_path(paths, values, t_xs, t_ys, max_neg_val=-1e9):
    """
    Example:
    ```python
    # assuming: import torch as tc
    paths = tc.randn((2, 3, 3)).numpy()
    values = tc.randn((2, 3, 3)).numpy()
    t_xs = tc.tensor([3, 3, 3]).numpy()
    t_ys = tc.tensor([3, 3]).numpy()

    # to display values (before) and paths:
    print("=====================")
    print("Paths:")
    print(paths)
    print("Original Values:")
    print(values)

    maximum_path(paths, values, t_xs, t_ys)

    print("Updated Values:")
    print(values)
    print("=====================")
    ```
    Outputs:
    ```md
    =====================
    Paths:
    [[[ 2.310408   -1.9375949  -0.57884663]
      [ 1.0308106   1.0793993   0.4461908 ]
      [ 0.26789713  0.48924422  0.3409592 ]]]
    Original Values:
    [[[-0.48256454  0.51348686 -1.8236492 ]
      [ 0.9949021  -0.6066166   0.18991096]
      [ 1.2555764  -0.24222293 -0.78757876]]]
    Updated Values:
    [[[-0.48256454  0.51348686 -1.8236492 ]
      [ 0.9949021  -1.0891812   0.18991096]
      [ 1.2555764  -0.24222293 -1.87676   ]]]
    =====================
    ```
    This may not be the standard, but may work for your project.
    """
    batch_size = values.shape[0]
    for i in prange(batch_size):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
````
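To complement the docstring above, a hedged end-to-end sketch of calling `maximum_path` on a batch of toy alignment scores. The contiguous float32/int32 NumPy layout is an assumption chosen to keep numba's `@njit` happy, and the import path is inferred from the wheel layout (`lt_tensor/monotonic_align.py`).

```python
import numpy as np
import torch

# Inferred import path based on lt_tensor/monotonic_align.py.
from lt_tensor.monotonic_align import maximum_path

B, T_x, T_y = 2, 4, 6  # batch, text steps, frame steps

# values[b, x, y]: per-step alignment scores (modified in place);
# paths[b, x, y]: filled with the chosen 0/1 monotonic alignment.
values = torch.randn(B, T_x, T_y).numpy().astype(np.float32)
paths = np.zeros((B, T_x, T_y), dtype=np.int32)
t_xs = np.full(B, T_x, dtype=np.int32)
t_ys = np.full(B, T_y, dtype=np.int32)

maximum_path(paths, values, t_xs, t_ys)

# Each batch element now holds a monotonic alignment path over [T_x, T_y].
print(paths[0])
```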