x-transformers 1.42.20__py3-none-any.whl → 1.42.21__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -655,7 +655,10 @@ class RotaryEmbedding(Module):
655
655
  def forward(self, t):
656
656
  max_pos = t.max() + 1
657
657
 
658
- freqs = torch.einsum('i , j -> i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
658
+ if t.ndim == 1:
659
+ t = rearrange(t, 'n -> 1 n')
660
+
661
+ freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
659
662
  freqs = torch.stack((freqs, freqs), dim = -1)
660
663
  freqs = rearrange(freqs, '... d r -> ... (d r)')
661
664
 
@@ -679,8 +682,8 @@ def rotate_half(x):
679
682
  def apply_rotary_pos_emb(t, freqs, scale = 1):
680
683
  rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
681
684
 
682
- freqs = freqs[-seq_len:, :]
683
- scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
685
+ freqs = freqs[:, -seq_len:, :]
686
+ scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
684
687
 
685
688
  if t.ndim == 4 and freqs.ndim == 3:
686
689
  freqs = rearrange(freqs, 'b n d -> b 1 n d')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.20
3
+ Version: 1.42.21
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
9
+ x_transformers/x_transformers.py,sha256=CCXzE-lRhcKoymSHeRO_ZvL1XMOhR3YlVo3Ovyk5BZw,96069
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.42.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.42.20.dist-info/METADATA,sha256=J0yBEg7oUfbkJaC3WxfB9Oq4XbGxXA5VjUGd9AHELGk,739
14
- x_transformers-1.42.20.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
- x_transformers-1.42.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.42.20.dist-info/RECORD,,
12
+ x_transformers-1.42.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.42.21.dist-info/METADATA,sha256=UEU2DHiMgksiQW5Ks9Mfvsup-d58vP_sR6DjLE8PSTQ,739
14
+ x_transformers-1.42.21.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
+ x_transformers-1.42.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.42.21.dist-info/RECORD,,