x-transformers 1.42.19__py3-none-any.whl → 1.42.21__py3-none-any.whl
- x_transformers/attend.py +1 -1
- x_transformers/x_transformers.py +6 -3
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/METADATA +1 -1
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/RECORD +7 -7
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias =
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
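A minimal standalone sketch of what the corrected line does, assuming the bias arrives as a per-head tensor of shape (heads, i, j) such as ALiBi; the sizes below are illustrative and not taken from the wheel:

import torch

# illustrative sizes; in attend.py, batch and heads come from the query tensor
batch, heads, i, j = 2, 8, 16, 16

# a per-head relative position bias shared across the batch (e.g. ALiBi)
attn_bias = torch.randn(heads, i, j)

# the fixed line: expand prepends a batch dim and broadcasts it to `batch`
# without copying memory, so the bias lines up with per-batch masks below
attn_bias = attn_bias.expand(batch, heads, -1, -1)

assert attn_bias.shape == (batch, heads, i, j)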
x_transformers/x_transformers.py
CHANGED
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1
 
-
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
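A sketch of the new batched path through RotaryEmbedding.forward, using stand-ins for the inv_freq buffer and interpolation_factor attribute referenced in the diff (values here are made up): a 1-D position tensor is lifted to shape (1, n), and the einsum carries the batch axis through, so per-sample position tensors also work.

import torch
from einops import rearrange

# stand-ins for the RotaryEmbedding attributes used in the diff
dim_head_half = 4
inv_freq = 1.0 / (10000 ** (torch.arange(dim_head_half).float() / dim_head_half))
interpolation_factor = 1.

def rotary_freqs(t):
    # positions may come in as (n,) or already batched as (b, n)
    if t.ndim == 1:
        t = rearrange(t, 'n -> 1 n')

    # per-batch outer product: (b, n) with (d,) -> (b, n, d)
    freqs = torch.einsum('b i , j -> b i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
    freqs = torch.stack((freqs, freqs), dim = -1)
    return rearrange(freqs, '... d r -> ... (d r)')

print(rotary_freqs(torch.arange(6)).shape)                 # torch.Size([1, 6, 8])
print(rotary_freqs(torch.arange(12).reshape(2, 6)).shape)  # torch.Size([2, 6, 8])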
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
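Why the slice gains a leading colon, shown on dummy tensors (shapes assumed, not taken from the wheel): once freqs carries the leading batch dim added above, the sequence axis is dim 1, so the old slice would have cut along the batch axis instead.

import torch

b, n, d, seq_len = 2, 10, 8, 6
freqs = torch.randn(b, n, d)   # rotary freqs now carry a leading batch dim

# old slicing assumed (n, d): on a batched tensor it slices the batch axis
# and leaves the sequence length untouched
old = freqs[-seq_len:, :]
print(old.shape)               # torch.Size([2, 10, 8]) -- sequence not trimmed

# corrected slicing keeps only the last `seq_len` positions
new = freqs[:, -seq_len:, :]
print(new.shape)               # torch.Size([2, 6, 8])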
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
-x_transformers/attend.py,sha256
+x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=CCXzE-lRhcKoymSHeRO_ZvL1XMOhR3YlVo3Ovyk5BZw,96069
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.21.dist-info/METADATA,sha256=UEU2DHiMgksiQW5Ks9Mfvsup-d58vP_sR6DjLE8PSTQ,739
+x_transformers-1.42.21.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.21.dist-info/RECORD,,
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/LICENSE
File without changes
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/WHEEL
File without changes
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/top_level.txt
File without changes