x-transformers 2.11.10__py3-none-any.whl → 2.11.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of x-transformers might be problematic. See the package's registry page for more details.

@@ -740,11 +740,14 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
740
740
  rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
741
741
 
742
742
  freqs = freqs[:, -seq_len:, :]
743
- scale = scale[:, -seq_len:, :] if isinstance(scale, torch.Tensor) else scale
743
+ scale = scale[:, -seq_len:, :] if is_tensor(scale) else scale
744
744
 
745
745
  if t.ndim == 4 and freqs.ndim == 3:
746
746
  freqs = rearrange(freqs, 'b n d -> b 1 n d')
747
747
 
748
+ if is_tensor(scale):
749
+ scale = rearrange(scale, 'b n d -> b 1 n d')
750
+
748
751
  # partial rotary embeddings, Wang et al. GPT-J
749
752
  t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
750
753
  t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
@@ -3438,6 +3441,7 @@ class TransformerWrapper(Module):
3438
3441
 
3439
3442
  kwargs = dict(
3440
3443
  **kwargs,
3444
+ pos = pos,
3441
3445
  seq_pos_offset = seq_pos_offset,
3442
3446
  seq_start_pos = seq_start_pos,
3443
3447
  input_not_include_cache = input_not_include_cache
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.11.10
3
+ Version: 2.11.12
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -11,10 +11,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
11
11
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
12
12
  x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
13
13
  x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
14
- x_transformers/x_transformers.py,sha256=bYnVtkcfr082ALprIGgYIUx53lLADGYpi9t6QEJp1Kc,126907
14
+ x_transformers/x_transformers.py,sha256=5ctPu8tvlbUMrtW360e_LPnoGv6xcgQFsyWdbvLo6Tk,127002
15
15
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
16
16
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
17
- x_transformers-2.11.10.dist-info/METADATA,sha256=xcHidmoWV-DKOo65NAd84GKA4kRQwSGMvWh69Rwh_w8,96012
18
- x_transformers-2.11.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
- x_transformers-2.11.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
20
- x_transformers-2.11.10.dist-info/RECORD,,
17
+ x_transformers-2.11.12.dist-info/METADATA,sha256=t1DVN4ub0rZHHhj9IMTqmnVTvYHWqnVW0_fNv77OGnU,96012
18
+ x_transformers-2.11.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
+ x_transformers-2.11.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
20
+ x_transformers-2.11.12.dist-info/RECORD,,