x-transformers 2.6.5__py3-none-any.whl → 2.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +8 -9
- x_transformers/x_transformers.py +2 -2
- {x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/METADATA +1 -1
- {x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/RECORD +6 -6
- {x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -176,7 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
-
+        head_learned_sink = False,
         selective = False,
         hard = False,
         cope = None,
@@ -257,10 +257,10 @@ class Attend(Module):
 
         # learned sink concatted pre-softmax, working solution from gpt-oss
 
-
-        assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'
+        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
 
-        self.
+        self.head_learned_sink = head_learned_sink
+        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
 
         # soft clamp attention logit value
 
@@ -517,10 +517,9 @@ class Attend(Module):
         if self.selective:
             sim = selective_attn(sim)
 
-        if self.
+        if self.head_learned_sink:
             # add learned attention sink
-
-            attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
+            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
             sim = cat((attn_sink, sim), dim = -1)
 
         pre_softmax_attn = sim
@@ -531,9 +530,9 @@ class Attend(Module):
 
         post_softmax_attn = attn
 
-        if self.
+        if self.head_learned_sink:
             # remove attention sink
-            attn = attn[...,
+            attn = attn[..., 1:]
 
         attn = self.attn_dropout(attn)
 
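Taken together, the attend.py hunks add one zero-initialized sink logit per head: it is concatenated to the attention logits before the softmax and its probability column is dropped afterwards, so a query row can shed probability mass onto "nothing" (the gpt-oss style attention sink). The following is a minimal sketch of that mechanic in plain torch + einops, written for illustration of this diff rather than taken from the library; shapes and variable names are chosen to mirror the diff.

# Minimal sketch (not the library's code) of the learned-sink mechanics
# introduced above: a zero-initialized logit per head is prepended to the
# attention scores before softmax, then its probability column is removed.
import torch
from torch import nn
from einops import repeat

b, heads, i, j = 2, 8, 16, 16

head_attn_sink = nn.Parameter(torch.zeros(heads))    # one learned sink logit per head

sim = torch.randn(b, heads, i, j)                    # raw attention logits

attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)
sim = torch.cat((attn_sink, sim), dim = -1)          # prepend the sink column

attn = sim.softmax(dim = -1)
attn = attn[..., 1:]                                 # remove the sink column post-softmax

assert attn.shape == (b, heads, i, j)                # each row now sums to <= 1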
x_transformers/x_transformers.py
CHANGED
@@ -1319,7 +1319,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
-
+        head_learned_sink = False,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1516,7 +1516,7 @@ class Attention(Module):
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
-
+            head_learned_sink = head_learned_sink,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
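The x_transformers.py hunks show that Attention now accepts head_learned_sink and forwards it into Attend. A hedged usage sketch is below: the constructor arguments shown (dim, heads, flash) and the new flag are taken from the diff and the library, but the forward-call return convention and any attn_head_learned_sink pass-through kwarg on Decoder / TransformerWrapper are assumptions not confirmed by this diff.

# Hedged usage sketch: enable the new per-head learned attention sink when
# constructing Attention directly. flash is kept off because the diff asserts
# it is not yet supported with head_learned_sink.
import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    head_learned_sink = True,   # new flag introduced in 2.6.6
    flash = False
)

x = torch.randn(1, 128, 512)    # (batch, seq, dim)
out, intermediates = attn(x)    # assumed return convention: (output, intermediates)
print(out.shape)                # expected: (1, 128, 512)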
{x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=JJv6ypJbZIFmH1LQ49hFg6hD0Wf9Z7Im1AP2ekm9hVI,18091
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=vjRMEMA12Js94YwLVeZksYMEoRgK6CSKT6TJViMPp7U,122186
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
+x_transformers-2.6.6.dist-info/METADATA,sha256=95CKrJ98X7R0hpb5D8GHSfi372UtxXDSeDaO2qB0Lrs,90445
+x_transformers-2.6.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.6.dist-info/RECORD,,
{x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/WHEEL
File without changes
{x_transformers-2.6.5.dist-info → x_transformers-2.6.6.dist-info}/licenses/LICENSE
File without changes