x-transformers 1.31.3__py3-none-any.whl → 1.31.5__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/attend.py +8 -6
- x_transformers/x_transformers.py +14 -15
- {x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/METADATA +1 -1
- {x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/RECORD +7 -7
- {x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -69,7 +69,7 @@ def onnx_create_causal_mask(i, j, device):
 
 # main class
 
-class Attend(nn.Module):
+class Attend(Module):
     def __init__(
         self,
         *,
@@ -81,7 +81,8 @@ class Attend(nn.Module):
         scale = None,
         qk_norm = False,
         flash = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         add_zero_kv = False,
         cope = None,
         onnxable = False,
@@ -123,10 +124,11 @@ class Attend(nn.Module):
 
         # soft clamp attention logit value
 
-        if exists(logit_softclamp_value):
+        if softclamp_logits:
             assert not flash, 'flash attention not compatible with logit softclamp value yet'
             assert logit_softclamp_value > 0.
 
+        self.softclamp_logits = softclamp_logits
         self.logit_softclamp_value = logit_softclamp_value
 
         # contextual positional encoding
@@ -308,6 +310,9 @@ class Attend(nn.Module):
         if exists(attn_bias):
             sim = sim + attn_bias
 
+        if self.softclamp_logits:
+            sim = softclamp(sim, self.logit_softclamp_value)
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -329,9 +334,6 @@ class Attend(nn.Module):
 
         pre_softmax_attn = sim.clone()
 
-        if exists(self.logit_softclamp_value):
-            sim = softclamp(sim, self.logit_softclamp_value)
-
         attn = self.attn_fn(sim, dim = -1)
         attn = attn.type(dtype)
 
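For context on what the relocated call does: the `softclamp` helper's body is not shown in this diff, but a common formulation of logit soft-clamping squashes the logits through a scaled tanh. The sketch below is illustrative only and assumes that form; the actual helper should be checked in `x_transformers/attend.py`.

```python
# Illustrative sketch of tanh-based logit soft-clamping; the formula is an
# assumption about the common (t / value).tanh() * value form, since the
# softclamp helper's body does not appear in this diff.
import torch

def softclamp(t: torch.Tensor, value: float) -> torch.Tensor:
    # smoothly bounds values to the open interval (-value, value)
    return (t / value).tanh() * value

sim = torch.randn(2, 8, 16, 16) * 100.   # exaggerated attention logits
clamped = softclamp(sim, 30.)            # 30. matches the new logit_softclamp_value default
assert clamped.abs().max().item() < 30.
```

Per the hunks above, the clamp now runs right after the attention bias is added, before masking and before the `pre_softmax_attn` clone, rather than after the clone as in 1.31.3, and it is gated by the new boolean `softclamp_logits` instead of the value being optional.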
x_transformers/x_transformers.py
CHANGED
@@ -566,7 +566,7 @@ class LayerNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
@@ -576,9 +576,7 @@ class LayerNorm(Module):
 
         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.gamma = nn.Parameter(torch.ones(dim))
-        nn.init.constant_(self.gamma, 1. - unit_offset)
-
-        self.register_buffer('beta', torch.zeros(dim), persistent = False)
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))
 
     def forward(self, x):
         normed = self.ln(x)
@@ -596,7 +594,6 @@ class AdaptiveLayerNorm(Module):
 
         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
-
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
@@ -608,14 +605,14 @@ class ScaleNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
 
         self.g = nn.Parameter(torch.zeros(1))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
@@ -624,14 +621,14 @@ class RMSNorm(Module):
     def __init__(
         self,
         dim,
-        unit_offset = 0.
+        unit_offset = False
    ):
         super().__init__()
         self.unit_offset = unit_offset
         self.scale = dim ** 0.5
 
         self.g = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.g, 1. - unit_offset)
+        nn.init.constant_(self.g, 1. - float(unit_offset))
 
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
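The norm changes all follow the same unit-offset pattern: the learnable scale is stored as an offset from 1 and initialized to 0, so generic weight decay pulls the effective scale toward 1 rather than toward 0, and the flag is now a plain bool. A minimal sketch, mirroring the RMSNorm lines in the hunk above (the class name here is illustrative, not part of the library):

```python
# Minimal sketch of the unit-offset trick, mirroring the RMSNorm hunk above.
# The class name is illustrative; only the logic is taken from the diff.
import torch
import torch.nn.functional as F
from torch import nn

class UnitOffsetRMSNorm(nn.Module):
    def __init__(self, dim, unit_offset = True):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5
        # gamma is stored as an offset from 1: weight decay on self.g now
        # decays the effective scale toward 1 instead of toward 0
        self.g = nn.Parameter(torch.zeros(dim))
        nn.init.constant_(self.g, 1. - float(unit_offset))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)

norm = UnitOffsetRMSNorm(512)
x = torch.randn(2, 16, 512)
out = norm(x)   # at init this is plain rms norm: g = 0, offset = 1
assert torch.allclose(out, F.normalize(x, dim = -1) * 512 ** 0.5)
```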
@@ -738,14 +735,14 @@ class LayerScale(Module):
         fn: Module,
         dim,
         init_value = 0.,
-        unit_offset = 0.
+        unit_offset = False
     ):
         super().__init__()
         self.unit_offset = unit_offset
 
         self.fn = fn
         self.gamma = nn.Parameter(torch.zeros(dim))
-        nn.init.constant_(self.gamma, init_value - unit_offset)
+        nn.init.constant_(self.gamma, init_value - float(unit_offset))
 
     def forward(self, x, **kwargs):
         out = self.fn(x, **kwargs)
@@ -887,7 +884,8 @@ class Attention(Module):
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
-        logit_softclamp_value = None,
+        softclamp_logits = False,
+        logit_softclamp_value = 30.,
         onnxable = False
     ):
         super().__init__()
@@ -990,6 +988,7 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             add_zero_kv = add_zero_kv,
             flash = flash,
+            softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
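The hunk above threads the two new flags from `Attention` down into `Attend`. As a quick smoke test, `Attend` can also be constructed directly; this is a sketch only, since just the constructor keywords below are confirmed by the diff, while the `(q, k, v)` call signature and the `(out, intermediates)` return are assumptions about the surrounding code.

```python
# Sketch only: constructor kwargs come from the diff; the forward signature
# (q, k, v) -> (out, intermediates) is an assumption.
import torch
from x_transformers.attend import Attend

attend = Attend(
    softclamp_logits = True,       # new flag in 1.31.5
    logit_softclamp_value = 30.,   # must be > 0; flash must stay False per the assert
)

q = k = v = torch.randn(1, 8, 128, 64)   # assumed (batch, heads, seq, dim_head) layout
out, intermediates = attend(q, k, v)
print(out.shape)
```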
@@ -1370,8 +1369,8 @@ class AttentionLayers(Module):
         norm_fn = partial(norm_class, dim)
 
         if not norm_need_condition and norm_add_unit_offset:
-            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas
-            norm_fn = partial(norm_fn, unit_offset = 1.)
+            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
+            norm_fn = partial(norm_fn, unit_offset = True)
 
         self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
@@ -1404,7 +1403,7 @@ class AttentionLayers(Module):
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
 
         if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
-            post_branch_fn = partial(post_branch_fn, unit_offset = 1.)
+            post_branch_fn = partial(post_branch_fn, unit_offset = True)
 
         # setup mlp for conditioning
 
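The last two hunks switch the opt-in from a float to a bool, but the wiring pattern is unchanged: the chosen norm class is partially applied once with `unit_offset = True` and then instantiated per layer. A self-contained sketch of that pattern, with a hypothetical stand-in for the configured `norm_class`:

```python
# Sketch of the partial(...) opt-in pattern from the two hunks above.
# DummyNorm is a hypothetical stand-in for the configured norm_class.
from functools import partial
import torch
from torch import nn

class DummyNorm(nn.Module):
    def __init__(self, dim, unit_offset = False):
        super().__init__()
        self.unit_offset = unit_offset
        self.g = nn.Parameter(torch.zeros(dim))
        nn.init.constant_(self.g, 1. - float(unit_offset))

    def forward(self, x):
        return x * (self.g + self.unit_offset)

norm_fn = partial(DummyNorm, 512)                 # dim is bound first, as in the diff
norm_fn = partial(norm_fn, unit_offset = True)    # boolean opt-in replaces the old float

norms = nn.ModuleList([norm_fn() for _ in range(6)])   # one norm per layer
```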
{x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=UWq0bElvJf-_j1N2QbJ2yg28xkWlhnOrLjMJt3If3ao,10956
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=VL9Dm8L5jnpgyt_V6DWGtLs9MeiTn7ZMCQdcHSFLVo8,75871
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
+x_transformers-1.31.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.5.dist-info/METADATA,sha256=o7z9SmDOf9BAU9FyMRuAbX4LIH8nIACA0acjGHPMz8s,661
+x_transformers-1.31.5.dist-info/WHEEL,sha256=mguMlWGMX-VHnMpKOjjQidIo1ssRlCFu4a4mBpz1s2M,91
+x_transformers-1.31.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.5.dist-info/RECORD,,
{x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/LICENSE
File without changes

{x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/WHEEL
File without changes

{x_transformers-1.31.3.dist-info → x_transformers-1.31.5.dist-info}/top_level.txt
File without changes