x-transformers 1.39.0__py3-none-any.whl → 1.39.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -107,6 +107,15 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
+# gumbel softmax
+
+def gumbel_softmax(t, temperature = 1.):
+    one_hot_indices = t.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+
+    t = (t / temperature).softmax(dim = -1)
+    return one_hot + t - t.detach()
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
 
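The new helper is a straight-through estimator: the forward pass emits the one-hot argmax of the logits, while gradients flow through the tempered softmax (despite the name, no Gumbel noise is sampled). A minimal standalone sketch of the same trick, assuming only PyTorch; the names and shapes below are illustrative, not from the package:

import torch

def straight_through_hard_softmax(t, temperature = 1.):
    # one-hot of the argmax along the last dim - this is the value the forward pass returns
    one_hot = torch.zeros_like(t).scatter(-1, t.argmax(dim = -1, keepdim = True), 1.)

    # tempered softmax - only its gradient is used
    soft = (t / temperature).softmax(dim = -1)

    # value: one_hot, since soft - soft.detach() is exactly zero
    # gradient: that of soft, since one_hot and soft.detach() contribute none
    return one_hot + soft - soft.detach()

sim = torch.randn(2, 8, 16, 16, requires_grad = True)   # attention logits (batch, heads, query, key)
v = torch.randn(2, 8, 16, 64)                           # values
out = straight_through_hard_softmax(sim) @ v            # hard (one-hot) attention in the forward pass
out.sum().backward()                                    # gradients still reach `sim` via the soft branch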
@@ -128,8 +137,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        pre_talking_heads = True,
-        post_talking_heads = True,
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -141,6 +151,7 @@ class Attend(Module):
         logit_softclamp_value = 50.,
         add_zero_kv = False,
         selective = False,
+        hard = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -161,17 +172,18 @@ class Attend(Module):
         # attention type
 
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
-        assert at_most_one_of(sigmoid, l2_distance)
-
-        self.sigmoid = sigmoid
+        assert not (flash and hard), 'hard attention not available for flash'
+        assert at_most_one_of(sigmoid, hard, l2_distance)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
-        elif not sigmoid:
+        elif sigmoid:
+            self.attn_fn = F.sigmoid
+        elif hard:
+            self.attn_fn = gumbel_softmax
+        else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
-        else:
-            self.attn_fn = F.sigmoid
 
         # dropouts
 
@@ -180,10 +192,11 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
 
         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
 
         if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
@@ -191,6 +204,10 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention
 
         assert not (flash and selective), 'selective attention cannot work on flash attention'
@@ -436,6 +453,9 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
         if exists(self.pre_softmax_talking_heads):
             sim = sim + self.pre_softmax_talking_heads(sim)
 
@@ -478,6 +498,7 @@ class Attend(Module):
             sim = sim + sim.sigmoid().log()
 
         attn = self.attn_fn(sim)
+
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
@@ -487,6 +508,9 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
 
         intermediates = Intermediates(
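Taken together, the two forward-pass hunks above give the new option its behavior: the pre-softmax logits are mixed across heads by a 1x1 convolution, and the result multiplicatively scales the post-softmax attention. A condensed sketch of that data flow, with illustrative shapes and names (not taken from the package):

import torch
from torch import nn

batch, heads, i, j = 2, 8, 16, 16
sim = torch.randn(batch, heads, i, j)            # pre-softmax attention logits (b h i j)

mix = nn.Conv2d(heads, heads, 1, bias = False)   # 1x1 conv mixes information across heads
nn.init.dirac_(mix.weight)                       # identity init, so the mix starts as a no-op

pre_to_post_scale = mix(sim)                     # computed from the *pre*-softmax logits...
attn = sim.softmax(dim = -1)                     # ordinary softmax attention
attn = attn * pre_to_post_scale                  # ...then applied as a scale *post*-softmax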
x_transformers/x_transformers.py CHANGED
@@ -909,6 +909,7 @@ class Attention(Module):
         flash = False,
         pre_talking_heads = False,
         post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -918,6 +919,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        hard = False,
         sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
@@ -1039,8 +1041,10 @@ class Attention(Module):
             causal = causal,
             pre_talking_heads = pre_talking_heads,
             post_talking_heads = post_talking_heads,
+            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
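Both new flags are plumbed from Attention into Attend, so they should be reachable from the high-level wrappers. A hedged usage sketch, assuming the library's usual `attn_`-prefixed kwarg routing from Decoder down to Attention (that routing is not part of this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_hard = True,                          # straight-through hard attention (new `hard` flag)
        attn_pre_scale_post_talking_heads = True   # new talking-heads variant added in this diff
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)   # (1, 256, 20000)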
@@ -1610,6 +1614,7 @@ class AttentionLayers(Module):
             layer = post_branch_fn(layer)
 
             residual_fn = GRUGating if gate_residual else Residual
+
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
             # handle unet skip connection
x_transformers-1.39.0.dist-info/METADATA → x_transformers-1.39.3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.0
+Version: 1.39.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.0.dist-info/RECORD → x_transformers-1.39.3.dist-info/RECORD
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
+x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
+x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
-x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.0.dist-info/RECORD,,
+x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
+x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.3.dist-info/RECORD,,
x_transformers-1.39.0.dist-info/WHEEL → x_transformers-1.39.3.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.1.0)
+Generator: setuptools (75.2.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 