x-transformers 1.42.19__py3-none-any.whl → 1.42.21__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- x_transformers/attend.py +1 -1
- x_transformers/x_transformers.py +6 -3
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/METADATA +1 -1
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/RECORD +7 -7
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float
 
         if exists(attn_bias):
-            attn_bias =
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)
 
         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
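The new line broadcasts the attention bias across the batch and head dimensions before the masking logic below is applied. A minimal sketch of the broadcast, with made-up shapes rather than the library's actual tensors:

import torch

batch, heads, i, j = 2, 8, 16, 16

# e.g. an alibi-style bias shared across the batch
attn_bias = torch.randn(1, heads, i, j)

# .expand returns a broadcast view; no memory is copied
attn_bias = attn_bias.expand(batch, heads, -1, -1)

print(attn_bias.shape)  # torch.Size([2, 8, 16, 16])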
x_transformers/x_transformers.py
CHANGED
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
     def forward(self, t):
         max_pos = t.max() + 1
 
-        freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
+        if t.ndim == 1:
+            t = rearrange(t, 'n -> 1 n')
+
+        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
         freqs = torch.stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')
 
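With this change, RotaryEmbedding.forward accepts positions of shape (n,) or already-batched positions of shape (b, n): a 1D input is promoted to a singleton batch, and the einsum then yields per-batch frequencies. A self-contained sketch of the new path, with standalone inv_freq and interpolation values chosen for illustration rather than the module's actual state:

import torch
from einops import rearrange

dim, interpolation_factor = 8, 1.

# standard RoPE inverse frequencies
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))

t = torch.arange(16)              # unbatched positions, shape (n,)

if t.ndim == 1:
    t = rearrange(t, 'n -> 1 n')  # promote to (1, n)

# batched outer product: (b, n) with (d/2,) -> (b, n, d/2)
freqs = torch.einsum('b i , j -> b i j', t.type_as(inv_freq), inv_freq) / interpolation_factor
freqs = torch.stack((freqs, freqs), dim = -1)
freqs = rearrange(freqs, '... d r -> ... (d r)')

print(freqs.shape)  # torch.Size([1, 16, 8])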
@@ -679,8 +682,8 @@ def rotate_half(x):
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
 
-    freqs = freqs[-seq_len:, :]
-    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
+    freqs = freqs[:, -seq_len:, :]
+    scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
     if t.ndim == 4 and freqs.ndim == 3:
         freqs = rearrange(freqs, 'b n d -> b 1 n d')
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
-x_transformers/attend.py,sha256
+x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=CCXzE-lRhcKoymSHeRO_ZvL1XMOhR3YlVo3Ovyk5BZw,96069
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.19.dist-info/METADATA,sha256=
-x_transformers-1.42.19.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.19.dist-info/RECORD,,
+x_transformers-1.42.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.21.dist-info/METADATA,sha256=UEU2DHiMgksiQW5Ks9Mfvsup-d58vP_sR6DjLE8PSTQ,739
+x_transformers-1.42.21.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.21.dist-info/RECORD,,
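For reference, each RECORD row is "path,sha256=<digest>,<size>", where the digest is the urlsafe-base64 SHA-256 of the file with padding stripped, per the standard wheel RECORD convention. A small sketch for recomputing an entry to verify a file (record_entry is a hypothetical helper, not part of x-transformers):

import base64
import hashlib

def record_entry(path):
    # hash the file the way wheel tooling does for RECORD rows
    data = open(path, 'rb').read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=')
    return f"{path},sha256={digest.decode()},{len(data)}"

print(record_entry('x_transformers/attend.py'))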
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/LICENSE
File without changes
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/WHEEL
File without changes
{x_transformers-1.42.19.dist-info → x_transformers-1.42.21.dist-info}/top_level.txt
File without changes