x-transformers 2.2.11__py3-none-any.whl → 2.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -62,6 +62,9 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def identity(t, *args, **kwargs):
+    return t
+
 def first(it, default = None):
     return it[0] if len(it) > 0 else default
 
@@ -74,7 +77,10 @@ def cast_tuple(val, depth = 1):
 def divisible_by(num, den):
     return (num % den) == 0
 
-def maybe(fn):
+def maybe(fn = None):
+    if not exists(fn):
+        fn = identity
+
     @wraps(fn)
     def inner(x, *args, **kwargs):
         if not exists(x):
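With this change, maybe can now be called with no function (or None) and falls back to wrapping identity, so maybe(self.sublayer_dropout)(out) further down is a no-op whenever the dropout module was never created. A minimal self-contained sketch of the new behavior, reusing the helper names from the hunk above:

from functools import wraps

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def maybe(fn = None):
    # new: fall back to identity when no function is given
    if not exists(fn):
        fn = identity

    @wraps(fn)
    def inner(x, *args, **kwargs):
        # unchanged: skip the call entirely when the input itself is None
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

double = lambda t: t * 2
assert maybe(double)(3) == 6        # a provided fn is applied as before
assert maybe(double)(None) is None  # None input is still short-circuited
assert maybe(None)(3) == 3          # no fn -> identity passthrough (new)
assert maybe()(3) == 3              # same, using the new default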
@@ -1199,6 +1205,7 @@ class FeedForward(Module):
         custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
+        sublayer_dropout = 0.,
         no_bias = False,
         zero_init_output = False
     ):
@@ -1227,7 +1234,8 @@ class FeedForward(Module):
             project_in,
             LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out, bias = not no_bias)
+            nn.Linear(inner_dim, dim_out, bias = not no_bias),
+            nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
         )
 
         # init last linear layer to 0
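A hedged usage sketch of the new FeedForward option: the sublayer_dropout parameter is taken from the hunks above, while the direct import path and the other keyword arguments are assumptions about the existing API rather than something shown in this diff.

import torch
from x_transformers.x_transformers import FeedForward  # assumed import path

# the extra Dropout sits after the output projection and is only created when the rate is > 0
ff = FeedForward(dim = 512, mult = 4, dropout = 0.1, sublayer_dropout = 0.1)

x = torch.randn(2, 1024, 512)
out = ff(x)  # shape preserved: (2, 1024, 512)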
@@ -1256,6 +1264,7 @@ class Attention(Module):
         sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
+        sublayer_dropout = 0.,
         on_attn = False,
         gate_value_heads = False,
         swiglu_values = False,
@@ -1534,6 +1543,10 @@ class Attention(Module):
         dim_out = default(dim_out, dim)
         self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
 
+        # sublayer dropout
+
+        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
+
         # the number of attention heads to rotate, for decoupled rope in multi-latent attention
 
         rotate_num_heads = default(rotate_num_heads, heads)
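Taken together with the forward-pass change in the next hunk, the effect is dropout applied to the attention sublayer's output, after the output projection, i.e. on the branch that the surrounding layer stack adds back to the residual stream. Below is a standalone sketch of that pattern, not the library's class; the residual placement is an assumption about how the layer stack consumes this output.

import torch
from torch import nn

class ResidualBranchWithSublayerDropout(nn.Module):
    """Hypothetical illustration: dropout on a sublayer's output before the residual add."""

    def __init__(self, sublayer: nn.Module, sublayer_dropout: float = 0.):
        super().__init__()
        self.sublayer = sublayer
        # mirror the diff: only create the dropout module when the rate is positive
        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None

    def forward(self, x):
        out = self.sublayer(x)
        if self.sublayer_dropout is not None:
            out = self.sublayer_dropout(out)
        return x + out  # residual add; in x-transformers this happens outside Attention/FeedForward

branch = ResidualBranchWithSublayerDropout(nn.Linear(512, 512), sublayer_dropout = 0.1)
y = branch(torch.randn(2, 16, 512))  # (2, 16, 512)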
@@ -1871,6 +1884,10 @@ class Attention(Module):
 
         out = self.to_out(out)
 
+        # maybe sublayer dropout
+
+        out = maybe(self.sublayer_dropout)(out)
+
         if exists(mask):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
x_transformers-2.2.11.dist-info/METADATA → x_transformers-2.2.12.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.2.11
+Version: 2.2.12
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.2.11.dist-info/RECORD → x_transformers-2.2.12.dist-info/RECORD

@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=m2xiiTafFZiII-QZLCpPerdWbY8O41I6BAYCaaPdXig,111953
+x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.2.11.dist-info/METADATA,sha256=riLwnpQD_lZYsGehLw8rt8Pja2qt2k8lRepqDgMKGwA,88687
-x_transformers-2.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.2.11.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.2.11.dist-info/RECORD,,
+x_transformers-2.2.12.dist-info/METADATA,sha256=cWj_UYsNQNf2botGDqO7GkyiUh3msLww0EilFMMhRS0,88687
+x_transformers-2.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.2.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.2.12.dist-info/RECORD,,