x-transformers 1.43.1.tar.gz → 1.43.4.tar.gz

Files changed (22)
  1. {x_transformers-1.43.1/x_transformers.egg-info → x_transformers-1.43.4}/PKG-INFO +1 -1
  2. {x_transformers-1.43.1 → x_transformers-1.43.4}/setup.py +1 -1
  3. {x_transformers-1.43.1 → x_transformers-1.43.4}/tests/test_x_transformers.py +4 -2
  4. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/x_transformers.py +6 -6
  5. {x_transformers-1.43.1 → x_transformers-1.43.4/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.43.1 → x_transformers-1.43.4}/LICENSE +0 -0
  7. {x_transformers-1.43.1 → x_transformers-1.43.4}/README.md +0 -0
  8. {x_transformers-1.43.1 → x_transformers-1.43.4}/setup.cfg +0 -0
  9. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.43.1/x_transformers.egg-info → x_transformers-1.43.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.1
+Version: 1.43.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.43.1 → x_transformers-1.43.4}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.1',
+  version = '1.43.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.43.1 → x_transformers-1.43.4}/tests/test_x_transformers.py

@@ -409,7 +409,8 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
-def test_custom_rotary_pos_emb():
+@pytest.mark.parametrize('rotary_xpos', (True, False))
+def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat
 
     model = TransformerWrapper(
@@ -419,7 +420,8 @@ def test_custom_rotary_pos_emb():
             dim = 512,
            depth = 2,
            heads = 8,
-            rotary_pos_emb = True
+            rotary_pos_emb = True,
+            rotary_xpos = rotary_xpos
         )
     )
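The test for custom rotary positions is now parametrized so it runs both with and without xpos length-extrapolation scaling. A minimal sketch of what this test exercises, using the same TransformerWrapper/Decoder arguments shown above; the input shapes and the batched pos tensor here are illustrative, not copied from the test:

```python
import torch
from einops import repeat
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True,
        rotary_xpos = True   # the parametrized test runs with both True and False
    )
)

x = torch.randint(0, 20000, (2, 32))                # token ids, batch of 2
pos = repeat(torch.arange(32), 'n -> b n', b = 2)   # custom, batched positions
logits = model(x, pos = pos)
```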
 
{x_transformers-1.43.1 → x_transformers-1.43.4}/x_transformers/x_transformers.py

@@ -666,7 +666,7 @@ class RotaryEmbedding(Module):
             return freqs, 1.
 
         power = (t - (max_pos // 2)) / self.scale_base
-        scale = self.scale ** rearrange(power, 'n -> n 1')
+        scale = self.scale ** rearrange(power, '... n -> ... n 1')
         scale = torch.stack((scale, scale), dim = -1)
         scale = rearrange(scale, '... d r -> ... (d r)')
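The only change here is the rearrange pattern: 'n -> n 1' assumes a 1-D position tensor, while '... n -> ... n 1' also accepts positions with leading (e.g. batch) dimensions, as the new parametrized test feeds in. A minimal einops illustration, with shapes chosen purely for the example:

```python
import torch
from einops import rearrange

power = torch.randn(2, 32)  # positions with a leading batch dim, shape (b, n)

rearrange(power, '... n -> ... n 1').shape   # torch.Size([2, 32, 1]) -- works
# rearrange(power, 'n -> n 1')               # raises: this pattern expects a 1-D tensor
```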
 
@@ -2270,16 +2270,16 @@ class AttentionLayers(Module):
         if self.need_condition:
             final_norm = maybe(partial)(final_norm, **norm_kwargs)
 
-        if self.resi_dual:
-            x = x + final_norm(outer_residual)
-        else:
-            x = final_norm(x)
-
         # take care of multistreams if needed, use sum for now
 
         if is_multistream:
             x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
 
+        if self.resi_dual:
+            x = x + final_norm(outer_residual)
+        else:
+            x = final_norm(x)
+
         if not return_hiddens:
             return x
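The resi_dual / final-norm block is moved below the multistream reduction, so the streams are summed back into a single tensor before the final norm (or the resi_dual residual) is applied. A rough sketch of the new ordering; the shapes and the LayerNorm stand-in are assumptions for illustration, not the library's exact code:

```python
import torch
from torch import nn
from einops import reduce

b, streams, n, d = 2, 4, 32, 512
final_norm = nn.LayerNorm(d)                # stand-in for the configured final norm

x = torch.randn(b * streams, n, d)          # hidden states, one row per stream
x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)   # merge streams first
x = final_norm(x)                           # then apply the final norm to the merged tensor
```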
 
{x_transformers-1.43.1 → x_transformers-1.43.4/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.1
+Version: 1.43.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
The remaining 17 files listed above have no changes.