x-transformers 1.26.6__py3-none-any.whl → 1.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +4 -3
- x_transformers/x_transformers.py +21 -8
- {x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/METADATA +1 -1
- {x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/RECORD +7 -7
- {x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/top_level.txt +0 -0
x_transformers/continuous.py
CHANGED
@@ -8,6 +8,7 @@ from x_transformers.x_transformers import (
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
+    LayerNorm,
     always,
     pad_at_dim
 )
@@ -54,7 +55,7 @@ class ContinuousTransformerWrapper(nn.Module):
         else:
             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

         # memory tokens
@@ -71,8 +72,8 @@ class ContinuousTransformerWrapper(nn.Module):

         # project in and out

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()

     def forward(
         self,
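For context, a minimal usage sketch of the wrapper touched above. The constructor keywords follow the upstream README examples and are assumed here; the assertions simply illustrate that, as of this release, the in/out projections no longer carry a bias term.

    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder

    model = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 100,
        max_seq_len = 1024,
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    x = torch.randn(1, 1024, 32)
    out = model(x)  # shape (1, 1024, 100)

    # the projections are now created with bias = False, so no bias parameter exists
    assert model.project_in.bias is None
    assert model.project_out.bias is None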
x_transformers/x_transformers.py
CHANGED
@@ -304,7 +304,7 @@ class DynamicPositionBias(nn.Module):

         self.mlp.append(Sequential(
             nn.Linear(1, dim),
-            nn.LayerNorm(dim) if norm else None,
+            LayerNorm(dim) if norm else None,
             nn.SiLU()
         ))

@@ -498,6 +498,19 @@ class ScaleNorm(nn.Module):
         norm = torch.norm(x, dim = -1, keepdim = True)
         return x / norm.clamp(min = self.eps) * self.g

+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        """
+        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
+        latest pytorch actually has a way to turn this off in nn.LayerNorm
+        """
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
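The docstring of the new LayerNorm notes that recent PyTorch can drop the bias from nn.LayerNorm directly. A minimal sketch of that equivalence, assuming PyTorch 2.1 or newer for the `bias` keyword:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dim = 512
    x = torch.randn(2, 8, dim)

    # what the new module computes: learnable gamma, beta fixed at zero (no learnable bias)
    gamma = torch.ones(dim)
    beta = torch.zeros(dim)
    custom = F.layer_norm(x, x.shape[-1:], gamma, beta)

    # built-in equivalent (the bias argument requires PyTorch >= 2.1)
    reference = nn.LayerNorm(dim, bias = False)(x)

    assert torch.allclose(custom, reference, atol = 1e-6)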
@@ -634,7 +647,7 @@ class FeedForward(nn.Module):

         self.ff = Sequential(
             project_in,
-            nn.LayerNorm(inner_dim) if post_act_ln else None,
+            LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
             nn.Linear(inner_dim, dim_out, bias = not no_bias)
         )
@@ -1083,7 +1096,7 @@ class AttentionLayers(nn.Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         else:
-            norm_class = nn.LayerNorm
+            norm_class = LayerNorm

         norm_fn = partial(norm_class, dim)

@@ -1415,12 +1428,12 @@ class ViTransformerWrapper(nn.Module):
         self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

         self.patch_to_embedding = nn.Sequential(
-            nn.LayerNorm(patch_dim),
+            LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim)
+            LayerNorm(dim)
         )

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.dropout = nn.Dropout(emb_dropout)

         self.attn_layers = attn_layers
@@ -1515,7 +1528,7 @@ class TransformerWrapper(nn.Module):

         self.emb_frac_gradient = emb_frac_gradient

-        self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1524,7 +1537,7 @@ class TransformerWrapper(nn.Module):
         self.init_()

         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

         # memory tokens (like [cls]) from Memory Transformers paper

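The to_logits change only touches the untied branch; when tie_embedding is set, logits are still produced by multiplying the hidden states with the transposed token-embedding matrix, so there is no separate projection (and no bias) at all. A small shape sketch with illustrative sizes:

    import torch

    num_tokens, dim = 20000, 512
    emb_weight = torch.randn(num_tokens, dim)   # stands in for token_emb.emb.weight
    hidden = torch.randn(2, 128, dim)           # final hidden states from the attention layers

    logits = hidden @ emb_weight.t()            # same computation as the lambda above
    assert logits.shape == (2, 128, num_tokens)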
{x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
 x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=Ra5IClCl9G7SAiM6L9w6iY4cCznH0dSGljC9AC_bNyw,6066
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=80qvAhandlAmt-mkiG7Ft6e5caCSDVRCFRfHppGvd5A,62216
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.6.dist-info/LICENSE,sha256=
-x_transformers-1.26.6.dist-info/METADATA,sha256=
-x_transformers-1.26.6.dist-info/WHEEL,sha256=
-x_transformers-1.26.6.dist-info/top_level.txt,sha256=
-x_transformers-1.26.6.dist-info/RECORD,,
+x_transformers-1.27.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.0.dist-info/METADATA,sha256=7kqhAXMJ-POUrT-6-QGBbgP4j7DCUR04bK8-ULmOYxQ,661
+x_transformers-1.27.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.0.dist-info/RECORD,,
{x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/LICENSE
File without changes
{x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/WHEEL
File without changes
{x_transformers-1.26.6.dist-info → x_transformers-1.27.0.dist-info}/top_level.txt
File without changes