x-transformers 1.42.19__py3-none-any.whl → 1.42.21__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
x_transformers/attend.py CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
370
370
  # convert from bool to float
371
371
 
372
372
  if exists(attn_bias):
373
- attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
373
+ attn_bias = attn_bias.expand(batch, heads, -1, -1)
374
374
 
375
375
  # if mask given, the mask would already contain the causal mask from above logic
376
376
  # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
655
655
  def forward(self, t):
656
656
  max_pos = t.max() + 1
657
657
 
658
- freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
658
+ if t.ndim == 1:
659
+ t = rearrange(t, 'n -> 1 n')
660
+
661
+ freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
659
662
  freqs = torch.stack((freqs, freqs), dim = -1)
660
663
  freqs = rearrange(freqs, '... d r -> ... (d r)')
661
664
 
@@ -679,8 +682,8 @@ def rotate_half(x):
679
682
  def apply_rotary_pos_emb(t, freqs, scale = 1):
680
683
  rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
681
684
 
682
- freqs = freqs[-seq_len:, :]
683
- scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
685
+ freqs = freqs[:, -seq_len:, :]
686
+ scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
684
687
 
685
688
  if t.ndim == 4 and freqs.ndim == 3:
686
689
  freqs = rearrange(freqs, 'b n d -> b 1 n d')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.19
3
+ Version: 1.42.21
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,16 +1,16 @@
1
1
  x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
2
- x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
2
+ x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
4
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
5
5
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
9
+ x_transformers/x_transformers.py,sha256=CCXzE-lRhcKoymSHeRO_ZvL1XMOhR3YlVo3Ovyk5BZw,96069
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.42.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.42.19.dist-info/METADATA,sha256=pJgi1Jp7FvM1o_x3a7uOaSJ8x0pNgIQnAp4lSI3K__o,739
14
- x_transformers-1.42.19.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
- x_transformers-1.42.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.42.19.dist-info/RECORD,,
12
+ x_transformers-1.42.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.42.21.dist-info/METADATA,sha256=UEU2DHiMgksiQW5Ks9Mfvsup-d58vP_sR6DjLE8PSTQ,739
14
+ x_transformers-1.42.21.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
+ x_transformers-1.42.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.42.21.dist-info/RECORD,,