x-transformers 1.31.4-py3-none-any.whl → 1.31.6-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
x_transformers/attend.py CHANGED
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):
 
 # main class
 
-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):
 
         # soft clamp attention logit value
 
-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.
 
+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value
 
         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias
 
+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):
 
         pre_softmax_attn = sim.clone()
 
-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
x_transformers/x_transformers.py CHANGED
@@ -884,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -987,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
@@ -1265,7 +1267,8 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
-        softclamp_output_value: float | None = None,
+        softclamp_output = False,
+        softclamp_output_value = 50.,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,
@@ -1482,6 +1485,7 @@ class AttentionLayers(Module):
         # optional soft clamping just before the final norm
         # used in gemma 2
 
+        self.softclamp_output = softclamp_output
         self.softclamp_output_value = softclamp_output_value
 
         # whether it has post norm
@@ -1715,7 +1719,7 @@ class AttentionLayers(Module):
         if return_hiddens:
             layer_hiddens.append(x)
 
-        if exists(self.softclamp_output_value):
+        if self.softclamp_output:
             x = softclamp(x, self.softclamp_output_value)
 
         final_norm = self.final_norm
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.4
+Version: 1.31.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=82gdvdwWWDWZVXTQRE6B4W-RQspg4SXhD8B6BV5pOo0,75789
+x_transformers/x_transformers.py,sha256=xPwzR3bd8BS_ChEcz0UxsNtx99u4UbP8jg1fFIRDGUw,75925
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.4.dist-info/METADATA,sha256=USK7uPjCATSS4URgfekTTmPhOtvHagHC02b-R-gxwME,661
-x_transformers-1.31.4.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
-x_transformers-1.31.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.4.dist-info/RECORD,,
+x_transformers-1.31.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.6.dist-info/METADATA,sha256=O2MZXNuX-jqrAdpcxkIDH4J0T63t5nt9utI8FHrRIA0,661
+x_transformers-1.31.6.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.6.dist-info/RECORD,,