x-transformers 1.30.14__py3-none-any.whl → 1.30.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +6 -17
- {x_transformers-1.30.14.dist-info → x_transformers-1.30.16.dist-info}/METADATA +1 -1
- {x_transformers-1.30.14.dist-info → x_transformers-1.30.16.dist-info}/RECORD +6 -6
- {x_transformers-1.30.14.dist-info → x_transformers-1.30.16.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.14.dist-info → x_transformers-1.30.16.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.14.dist-info → x_transformers-1.30.16.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -80,13 +80,6 @@ class equals():
|
|
80
80
|
def __call__(self, x, *args, **kwargs):
|
81
81
|
return x == self.val
|
82
82
|
|
83
|
-
class Identity(Module):
|
84
|
-
def __init__(self, *args, **kwargs):
|
85
|
-
super().__init__()
|
86
|
-
|
87
|
-
def __call__(self, x, *args, **kwargs):
|
88
|
-
return x
|
89
|
-
|
90
83
|
def Sequential(*modules):
|
91
84
|
return nn.Sequential(*filter(exists, modules))
|
92
85
|
|
@@ -1366,18 +1359,14 @@ class AttentionLayers(Module):
|
|
1366
1359
|
residual_fn = GRUGating if gate_residual else Residual
|
1367
1360
|
residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
|
1368
1361
|
|
1369
|
-
# for the peri-layernorm config from https://arxiv.org/abs/2405.16039
|
1370
|
-
# must be paired with qk norm
|
1371
|
-
|
1372
|
-
layer_norm_fn = norm_fn
|
1373
|
-
if no_pre_or_postnorm:
|
1374
|
-
layer_norm_fn = Identity
|
1375
|
-
|
1376
1362
|
# all normalizations of the layer
|
1377
1363
|
|
1378
|
-
pre_branch_norm =
|
1379
|
-
|
1380
|
-
|
1364
|
+
pre_branch_norm = post_branch_norm = post_main_norm = None
|
1365
|
+
|
1366
|
+
if not no_pre_or_postnorm:
|
1367
|
+
pre_branch_norm = norm_fn() if pre_norm else None
|
1368
|
+
post_branch_norm = norm_fn() if sandwich_norm else None
|
1369
|
+
post_main_norm = norm_fn() if not pre_norm else None
|
1381
1370
|
|
1382
1371
|
norms = ModuleList([
|
1383
1372
|
pre_branch_norm,
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=G4aF6SHVvhcu8OHFwCj0ZlJahEeGR1lUaRglNHcK74k,69225
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.30.
|
11
|
-
x_transformers-1.30.
|
12
|
-
x_transformers-1.30.
|
13
|
-
x_transformers-1.30.
|
14
|
-
x_transformers-1.30.
|
10
|
+
x_transformers-1.30.16.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.30.16.dist-info/METADATA,sha256=6BifB-CeW-wD7SjuXrKCC5dbJUnq_iLEMl0DXcszvt0,662
|
12
|
+
x_transformers-1.30.16.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.30.16.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.30.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|