x-transformers 1.39.1__py3-none-any.whl → 1.39.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -107,6 +107,15 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')

+# gumbel softmax
+
+def gumbel_softmax(t, temperature = 1.):
+    one_hot_indices = t.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+
+    t = (t / temperature).softmax(dim = -1)
+    return one_hot + t - t.detach()
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)

@@ -142,6 +151,7 @@ class Attend(Module):
         logit_softclamp_value = 50.,
         add_zero_kv = False,
         selective = False,
+        hard = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
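The first attend.py hunk adds a gumbel_softmax helper and the second exposes a hard flag on Attend that selects it as the attention function. Despite its name, the helper samples no Gumbel noise; it is a straight-through hard argmax: the forward value is a one-hot vector, while gradients flow through the tempered softmax term (one_hot + t - t.detach()). A minimal, self-contained sketch of that behavior (not part of the package), copying the new helper verbatim:

import torch

def gumbel_softmax(t, temperature = 1.):
    # hard one-hot of the argmax along the last dimension
    one_hot_indices = t.argmax(dim = -1, keepdim = True)
    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)

    # tempered softmax used only to carry gradients (straight-through estimator)
    t = (t / temperature).softmax(dim = -1)
    return one_hot + t - t.detach()

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad = True)
out = gumbel_softmax(logits)

print(out)            # tensor([[1., 0., 0.]], grad_fn=...) - hard one-hot in the forward pass
out[0, 0].backward()
print(logits.grad)    # nonzero - gradient of the softmax, not of the (non-differentiable) argmax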
@@ -162,17 +172,18 @@ class Attend(Module):
         # attention type

         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
-        assert at_most_one_of(sigmoid, l2_distance)
-
-        self.sigmoid = sigmoid
+        assert not (flash and hard), 'hard attention not available for flash'
+        assert at_most_one_of(sigmoid, hard, l2_distance)

         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
-        elif not sigmoid:
+        elif sigmoid:
+            self.attn_fn = F.sigmoid
+        elif hard:
+            self.attn_fn = gumbel_softmax
+        else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
-        else:
-            self.attn_fn = F.sigmoid

         # dropouts

@@ -487,6 +498,7 @@ class Attend(Module):
             sim = sim + sim.sigmoid().log()

         attn = self.attn_fn(sim)
+
         attn = attn.type(dtype)

         post_softmax_attn = attn
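With the dispatch above, Attend(hard = True) routes the attention weights through gumbel_softmax instead of softmax and, per the new assertion, cannot be combined with flash = True. A hedged sketch of driving Attend directly, assuming the forward signature of recent releases (q, k, v shaped (batch, heads, seq, dim_head), returning the output plus an intermediates object); consult the installed version if this differs:

import torch
from x_transformers.attend import Attend

# straight-through one-hot attention; not usable together with flash = True
attend = Attend(hard = True)

# batch of 1, 8 heads, sequence length 128, head dimension 64
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out, intermediates = attend(q, k, v)
print(out.shape)   # torch.Size([1, 8, 128, 64])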
x_transformers/x_transformers.py CHANGED
@@ -919,6 +919,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        hard = False,
         sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
@@ -1043,6 +1044,7 @@ class Attention(Module):
             pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
@@ -1612,6 +1614,7 @@ class AttentionLayers(Module):
                layer = post_branch_fn(layer)

            residual_fn = GRUGating if gate_residual else Residual
+
            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)

            # handle unet skip connection
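At the user-facing level, the new flag is just another Attention keyword argument, so it should be reachable through x-transformers' usual attn_-prefixed kwarg routing. A hedged usage sketch (attn_hard is assumed to route to the hard parameter added above; adjust if the installed version differs):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_hard = True    # straight-through hard attention; incompatible with attn_flash = True
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)           # (1, 256, 20000)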
x_transformers-1.39.1.dist-info/METADATA → x_transformers-1.39.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.1
+Version: 1.39.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.1.dist-info/RECORD → x_transformers-1.39.3.dist-info/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=JlWwEzigY5sl9ktDwDWWQD9np9uPRoj2eRo9XU6tJc0,16273
+x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=8ZQR6OLT4vusIjJXzrdSp12Fydmmpcc2t5cDE6SxPNc,84460
+x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.1.dist-info/METADATA,sha256=KJfw4hIDzozyRlsnBaqbbfcFSXygbv1y2B_6cl9pu-4,661
-x_transformers-1.39.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.39.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.1.dist-info/RECORD,,
+x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
+x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.3.dist-info/RECORD,,
x_transformers-1.39.1.dist-info/WHEEL → x_transformers-1.39.3.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.1.0)
+Generator: setuptools (75.2.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
