x-transformers 1.31.4__tar.gz → 1.31.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {x_transformers-1.31.4/x_transformers.egg-info → x_transformers-1.31.5}/PKG-INFO +1 -1
  2. {x_transformers-1.31.4 → x_transformers-1.31.5}/setup.py +1 -1
  3. {x_transformers-1.31.4 → x_transformers-1.31.5}/tests/test_x_transformers.py +17 -1
  4. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/attend.py +8 -6
  5. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/x_transformers.py +3 -1
  6. {x_transformers-1.31.4 → x_transformers-1.31.5/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.31.4 → x_transformers-1.31.5}/LICENSE +0 -0
  8. {x_transformers-1.31.4 → x_transformers-1.31.5}/README.md +0 -0
  9. {x_transformers-1.31.4 → x_transformers-1.31.5}/setup.cfg +0 -0
  10. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  15. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  16. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/xval.py +0 -0
  17. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers.egg-info/SOURCES.txt +0 -0
  18. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers.egg-info/dependency_links.txt +0 -0
  19. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers.egg-info/requires.txt +0 -0
  20. {x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.31.4/x_transformers.egg-info → x_transformers-1.31.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.4
+Version: 1.31.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.31.4 → x_transformers-1.31.5}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.31.4',
+  version = '1.31.5',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.31.4 → x_transformers-1.31.5}/tests/test_x_transformers.py
@@ -1,5 +1,4 @@
 import torch
-
 from x_transformers.x_transformers import (
     XTransformer,
     TransformerWrapper,
@@ -111,3 +110,20 @@ def test_adaptive_rmsnorm():
     condition = torch.randn(2, 768)

     model(x, condition = condition)
+
+def test_attn_softclamp_logits():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            dim_condition = 768,
+            depth = 12,
+            heads = 8,
+            attn_softclamp_logits = True,
+        )
+    )
+
+    x = torch.randint(0, 256, (1, 1024))
+
+    model(x)
{x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/attend.py
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):

 # main class

-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):

         # soft clamp attention logit value

-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.

+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value

         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias

+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype

         mask_value = -torch.finfo(sim.dtype).max
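The new placement above clamps the raw similarity logits right after any attention bias is added, before masking and softmax (the old just-before-softmax placement is removed in the next hunk). A minimal sketch of the clamping itself, assuming x_transformers' softclamp helper follows the usual tanh soft-capping form:

import torch

def softclamp(t, value):
    # assumed definition: smoothly bounds logits to (-value, value),
    # keeping gradients finite instead of hard-clipping
    return (t / value).tanh() * value

sim = torch.randn(1, 8, 128, 128) * 100.   # exaggerated attention logits
clamped = softclamp(sim, 30.)              # magnitudes now capped at roughly 30.
assert clamped.abs().max() <= 30.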
@@ -329,9 +334,6 @@ class Attend(nn.Module):

         pre_softmax_attn = sim.clone()

-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)

{x_transformers-1.31.4 → x_transformers-1.31.5}/x_transformers/x_transformers.py
@@ -884,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -987,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
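At the user level the new flag is routed through the usual attn_-prefixed kwargs, as the added test above demonstrates. A hedged usage sketch; the attn_logit_softclamp_value spelling for overriding the 30. default is an assumption based on that prefix routing:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_softclamp_logits = True,       # soft-cap attention logits (non-flash path only)
        attn_logit_softclamp_value = 30.,   # assumed kwarg name; matches the new default
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)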
{x_transformers-1.31.4 → x_transformers-1.31.5/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.4
+Version: 1.31.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang