x-transformers 1.42.20__py3-none-any.whl → 1.42.22__py3-none-any.whl
- x_transformers/x_transformers.py +7 -4
- {x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/METADATA +1 -1
- {x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/RECORD +6 -6
- {x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1

-
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')

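With this change, RotaryEmbedding.forward accepts positions with a leading batch dimension and lifts plain 1-D position ids to a single batch. A minimal sketch of that behaviour follows; rotary_freqs, the toy inv_freq and interpolation_factor are illustrative stand-ins, not the library's actual module state.

# Minimal sketch of the batched position handling in the hunk above.
# `rotary_freqs`, `inv_freq` and `interpolation_factor` are placeholders.
import torch
from einops import rearrange

dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim / 2,)
interpolation_factor = 1.0

def rotary_freqs(t):
    if t.ndim == 1:                   # plain position ids of shape (n,)
        t = rearrange(t, 'n -> 1 n')  # lift to a single batch: (1, n)

    freqs = torch.einsum('b i , j -> b i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
    freqs = torch.stack((freqs, freqs), dim = -1)
    return rearrange(freqs, '... d r -> ... (d r)')

print(rotary_freqs(torch.arange(5)).shape)                 # torch.Size([1, 5, 8])
print(rotary_freqs(torch.arange(10).reshape(2, 5)).shape)  # torch.Size([2, 5, 8])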
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale

     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
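Because freqs now carries a leading batch dimension, the sequence-length slice moves to the second axis. A toy shape check, with made-up tensor sizes:

# Toy shape check for the new slicing; all sizes here are illustrative.
import torch

freqs = torch.randn(2, 10, 8)    # (batch, positions, rot_dim) after the change above
t     = torch.randn(2, 4, 6, 8)  # (batch, heads, seq_len, head_dim)
seq_len = t.shape[-2]

freqs = freqs[:, -seq_len:, :]   # previously freqs[-seq_len:, :] when freqs was (positions, rot_dim)
print(freqs.shape)               # torch.Size([2, 6, 8])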
@@ -1462,7 +1465,7 @@ class Attention(Module):
         # laser

         if self.laser:
-            out =
+            out = log(out) + values_max

         # store the values for resformer or Neutreno

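The restored line completes a log-space round trip: earlier in the method (outside this hunk) the values are shifted by their max and exponentiated before attention is applied, so the output is mapped back with log and the max added back. A self-contained sketch under that assumption:

# Sketch of the LASER round trip; the pre-attention exponentiation shown here is
# assumed context from outside this hunk, not part of the diff itself.
import torch
from torch import log

v = torch.randn(2, 4, 6, 8)                              # (batch, heads, seq, dim_head)
values_max = v.amax(dim = -1, keepdim = True).detach()
v = (v - values_max).exp()                               # attention runs over exp-space values

attn = torch.softmax(torch.randn(2, 4, 6, 6), dim = -1)  # stand-in attention matrix
out = attn @ v

out = log(out) + values_max                              # the line this diff restores
print(out.shape)                                         # torch.Size([2, 4, 6, 8])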
{x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=AtxLfcaVabAKJdJ9xOKVrATDcyjxG-tFXx6lg941WB8,96068
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.22.dist-info/METADATA,sha256=M3wgytCy3B8zW_g2qUesrZAgzhZLBBy-60HjPDpHjNM,739
+x_transformers-1.42.22.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.22.dist-info/RECORD,,
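For reference, each RECORD line is path,sha256=&lt;urlsafe-base64 digest without padding&gt;,&lt;size in bytes&gt; (PEP 376 / the wheel RECORD format). A small sketch of how such an entry is produced, using placeholder file contents rather than the real wheel:

# Sketch of how a RECORD entry is formed; path and contents are placeholders.
import base64, hashlib

def record_entry(path: str, data: bytes) -> str:
    digest  = hashlib.sha256(data).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
    return f'{path},sha256={encoded},{len(data)}'

print(record_entry('x_transformers/x_transformers.py', b'example file contents'))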
{x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/LICENSE
File without changes
{x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/WHEEL
File without changes
{x_transformers-1.42.20.dist-info → x_transformers-1.42.22.dist-info}/top_level.txt
File without changes