x-transformers 2.2.11__py3-none-any.whl → 2.2.12__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/x_transformers.py +19 -2
- {x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/METADATA +1 -1
- {x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/RECORD +5 -5
- {x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/WHEEL +0 -0
- {x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED

```diff
@@ -62,6 +62,9 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def identity(t, *args, **kwargs):
+    return t
+
 def first(it, default = None):
     return it[0] if len(it) > 0 else default
 
@@ -74,7 +77,10 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
-def maybe(fn):
+def maybe(fn = None):
+    if not exists(fn):
+        fn = identity
+
     @wraps(fn)
     def inner(x, *args, **kwargs):
         if not exists(x):
```
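The new `identity` helper lets `maybe` be called with no function and behave as a pass-through. Below is a minimal, self-contained sketch of that behavior; the body of `inner` beyond the hunk context is an assumption based on the surrounding helpers, not part of this diff.

```python
from functools import wraps

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def maybe(fn = None):
    # new in 2.2.12: fall back to identity when no function is supplied
    if not exists(fn):
        fn = identity

    @wraps(fn)
    def inner(x, *args, **kwargs):
        # assumed body (outside the hunk context): skip the call on None inputs
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)

    return inner

assert maybe(None)(5) == 5                   # no function -> pass-through
assert maybe(lambda x: x * 2)(5) == 10       # wraps a given function
assert maybe(lambda x: x * 2)(None) is None  # None inputs are still skipped
```

This fallback is what lets `maybe(self.sublayer_dropout)(out)` later in the diff become a no-op when sublayer dropout is disabled.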
```diff
@@ -1199,6 +1205,7 @@ class FeedForward(Module):
         custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
+        sublayer_dropout = 0.,
         no_bias = False,
         zero_init_output = False
     ):
@@ -1227,7 +1234,8 @@ class FeedForward(Module):
             project_in,
             LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out, bias = not no_bias)
+            nn.Linear(inner_dim, dim_out, bias = not no_bias),
+            nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )
 
         # init last linear layer to 0
```
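As a hedged usage sketch (the dimensions and dropout rates below are illustrative, not taken from the diff), the new `sublayer_dropout` argument can be passed when constructing the feedforward block directly from the module shown above; when the rate is positive, an extra `nn.Dropout` is appended after the output projection.

```python
import torch
from x_transformers.x_transformers import FeedForward

ff = FeedForward(
    dim = 512,
    mult = 4,
    dropout = 0.1,           # existing dropout inside the block
    sublayer_dropout = 0.1   # new: dropout after the output projection
)

x = torch.randn(2, 128, 512)
out = ff(x)                  # shape preserved: (2, 128, 512)
```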
```diff
@@ -1256,6 +1264,7 @@ class Attention(Module):
         sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
+        sublayer_dropout = 0.,
         on_attn = False,
         gate_value_heads = False,
         swiglu_values = False,
@@ -1534,6 +1543,10 @@ class Attention(Module):
         dim_out = default(dim_out, dim)
         self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
 
+        # sublayer dropout
+
+        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
+
         # the number of attention heads to rotate, for decoupled rope in multi-latent attention
 
         rotate_num_heads = default(rotate_num_heads, heads)
@@ -1871,6 +1884,10 @@ class Attention(Module):
 
         out = self.to_out(out)
 
+        # maybe sublayer dropout
+
+        out = maybe(self.sublayer_dropout)(out)
+
         if exists(mask):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
```
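A similar hedged sketch for the attention sublayer (again with illustrative dimensions and rates; the two-value return follows the current class). `Attention` applies the new dropout right after its output projection, and because `self.sublayer_dropout` is `None` when the rate is zero, `maybe(...)` resolves to the identity wrapper and the forward pass is unchanged.

```python
import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    dropout = 0.1,           # existing dropout on the attention weights
    sublayer_dropout = 0.1   # new: dropout on the sublayer output after to_out
)

x = torch.randn(2, 128, 512)
out, intermediates = attn(x)  # output shape: (2, 128, 512)
```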
{x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/RECORD
CHANGED

```diff
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.
-x_transformers-2.2.
-x_transformers-2.2.
-x_transformers-2.2.
+x_transformers-2.2.12.dist-info/METADATA,sha256=cWj_UYsNQNf2botGDqO7GkyiUh3msLww0EilFMMhRS0,88687
+x_transformers-2.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.12.dist-info/RECORD,,
```

{x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/WHEEL
File without changes

{x_transformers-2.2.11.dist-info → x_transformers-2.2.12.dist-info}/licenses/LICENSE
File without changes