x-transformers 1.26.5__py3-none-any.whl → 1.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +1 -1
- x_transformers/continuous.py +4 -3
- x_transformers/x_transformers.py +22 -9
- {x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/METADATA +1 -1
- x_transformers-1.27.0.dist-info/RECORD +13 -0
- x_transformers-1.26.5.dist-info/RECORD +0 -13
- {x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -190,7 +190,7 @@ class AutoregressiveWrapper(Module):
         if restrict_to_max_seq_len:
             max_len_exceeded = out.shape[-1] > max_seq_len

-            assert not (cache_kv and max_len_exceeded and self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'
+            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'

             x = out[:, -max_seq_len:]

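The only change in this file flips the polarity of the guard: 1.26.5 raised when the network could cache key/values beyond max_seq_len, the opposite of what the error message describes, while 1.27.0 raises only when it cannot. Below is a minimal, hedged sketch of the setup the message recommends (rotary embeddings instead of absolute positional embeddings); the constructor and generate() arguments are taken from the x-transformers README and are not verified against this exact release.

# Hedged sketch: decoder with rotary embeddings, so kv caching keeps working
# when generation runs past max_seq_len. Argument names assumed, not verified
# against this exact release.
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    use_abs_pos_emb = False,      # no absolute positional embedding
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        rotary_pos_emb = True     # rotary embeddings work at any position
    )
)

wrapper = AutoregressiveWrapper(model)
prompt = torch.randint(0, 256, (1, 4))

# generate more tokens than max_seq_len with kv caching enabled
out = wrapper.generate(prompt, 200, cache_kv = True)

With use_abs_pos_emb = False the wrapper's can_cache_kv_outside_max_seq_len flag (see the TransformerWrapper change below) is set, so cached decoding past max_seq_len passes the corrected assertion.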
x_transformers/continuous.py
CHANGED
@@ -8,6 +8,7 @@ from x_transformers.x_transformers import (
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
+    LayerNorm,
     always,
     pad_at_dim
 )
@@ -54,7 +55,7 @@ class ContinuousTransformerWrapper(nn.Module):
         else:
             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

         # memory tokens
@@ -71,8 +72,8 @@ class ContinuousTransformerWrapper(nn.Module):

         # project in and out

-        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
-        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+        self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()

     def forward(
         self,
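None of these changes touch the public interface of ContinuousTransformerWrapper: the post-embedding norm is now the bias-less LayerNorm and the input/output projections drop their bias terms. A hedged usage sketch, with constructor arguments assumed from the x-transformers README:

# Hedged sketch: continuous (non-token) inputs in, continuous outputs out.
# dim_in / dim_out drive the now bias-free project_in / project_out layers.
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    post_emb_norm = True,             # backed by the bias-less LayerNorm as of 1.27.0
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randn(1, 1024, 32)
mask = torch.ones(1, 1024).bool()

out = model(x, mask = mask)           # (1, 1024, 100)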
x_transformers/x_transformers.py
CHANGED
@@ -304,7 +304,7 @@ class DynamicPositionBias(nn.Module):

         self.mlp.append(Sequential(
             nn.Linear(1, dim),
-            nn.LayerNorm(dim) if norm else None,
+            LayerNorm(dim) if norm else None,
             nn.SiLU()
         ))

@@ -498,6 +498,19 @@ class ScaleNorm(nn.Module):
         norm = torch.norm(x, dim = -1, keepdim = True)
         return x / norm.clamp(min = self.eps) * self.g

+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        """
+        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
+        latest pytorch actually has a way to turn this off in nn.LayerNorm
+        """
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 class RMSNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
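The new LayerNorm keeps a learnable gamma but registers beta as a fixed zero buffer, so there is no learnable shift. As the docstring notes, recent PyTorch can express the same thing directly; a hedged equivalence check (assuming torch >= 2.1, where nn.LayerNorm accepts bias = False) might look like this:

# Hedged sketch: the new bias-less layernorm, reproduced standalone under a
# hypothetical name, compared against nn.LayerNorm with its bias disabled.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiaslessLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))          # learnable scale
        self.register_buffer("beta", torch.zeros(dim))      # fixed zero shift

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

x = torch.randn(2, 8, 16)

custom = BiaslessLayerNorm(16)
builtin = nn.LayerNorm(16, bias = False)                    # requires torch >= 2.1

# identical at initialization: same normalization, learnable scale, no learnable shift
assert torch.allclose(custom(x), builtin(x), atol = 1e-6)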
@@ -634,7 +647,7 @@ class FeedForward(nn.Module):

         self.ff = Sequential(
             project_in,
-            nn.LayerNorm(inner_dim) if post_act_ln else None,
+            LayerNorm(inner_dim) if post_act_ln else None,
             nn.Dropout(dropout),
             nn.Linear(inner_dim, dim_out, bias = not no_bias)
         )
@@ -1083,7 +1096,7 @@ class AttentionLayers(nn.Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         else:
-            norm_class = nn.LayerNorm
+            norm_class = LayerNorm

         norm_fn = partial(norm_class, dim)

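The default branch of the norm selection now resolves to the new bias-less LayerNorm; the other branches are chosen by flags on the attention layers. A hedged configuration sketch, with flag names taken from this hunk's context and the x-transformers README:

# Hedged sketch: use_scalenorm / use_rmsnorm / use_simple_rmsnorm pick the other
# norm classes; with none of them set, the default is now the bias-less LayerNorm.
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_rmsnorm = True    # switch the pre-normalization layers to RMSNorm
    )
)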
@@ -1415,12 +1428,12 @@ class ViTransformerWrapper(nn.Module):
         self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

         self.patch_to_embedding = nn.Sequential(
-            nn.LayerNorm(patch_dim),
+            LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim)
+            LayerNorm(dim)
         )

-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.dropout = nn.Dropout(emb_dropout)

         self.attn_layers = attn_layers
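Here the patch-to-embedding stack and the post-embedding norm both move to the bias-less LayerNorm; the constructor surface is unchanged. A hedged usage sketch with arguments assumed from the x-transformers README:

# Hedged sketch: ViT built on the x-transformers Encoder; patch_to_embedding and
# post_emb_norm now use the bias-less LayerNorm internally.
import torch
from x_transformers import ViTransformerWrapper, Encoder

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

img = torch.randn(1, 3, 256, 256)
logits = vit(img)    # (1, 1000)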
@@ -1515,7 +1528,7 @@ class TransformerWrapper(nn.Module):

         self.emb_frac_gradient = emb_frac_gradient

-        self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)

         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1524,7 +1537,7 @@ class TransformerWrapper(nn.Module):
         self.init_()

         logits_dim = default(logits_dim, num_tokens)
-        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+        self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

         # memory tokens (like [cls]) from Memory Transformers paper

@@ -1538,7 +1551,7 @@ class TransformerWrapper(nn.Module):
         # whether can do cached kv decoding

         self.can_cache_kv = self.num_memory_tokens == 0
-        self.can_cache_kv_outside_max_seq_len =
+        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

     def init_(self):
         if self.l2norm_embed:
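Two notes on this last group of hunks. First, with tie_embedding = True the logits are computed as t @ token_emb.emb.weight.t(), a bias-free projection, so making the untied nn.Linear head bias-free as well gives both paths the same functional form. Second, can_cache_kv_outside_max_seq_len is now assigned from no_abs_pos_emb, which is the flag the corrected assertion in autoregressive_wrapper.py reads through self.net. A hedged sketch of the tied versus untied heads, with arguments assumed from the x-transformers README:

# Hedged sketch: tied and untied output heads produce logits of the same shape;
# as of 1.27.0 both are bias-free projections from dim to num_tokens.
import torch
from x_transformers import TransformerWrapper, Decoder

tied = TransformerWrapper(
    num_tokens = 1000,
    max_seq_len = 256,
    tie_embedding = True,     # reuse the token embedding matrix as the output head
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

untied = TransformerWrapper(
    num_tokens = 1000,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

x = torch.randint(0, 1000, (1, 256))

assert tied(x).shape == untied(x).shape == (1, 256, 1000)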
x_transformers-1.27.0.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
+x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
+x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
+x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
+x_transformers/continuous.py,sha256=Ra5IClCl9G7SAiM6L9w6iY4cCznH0dSGljC9AC_bNyw,6066
+x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+x_transformers/x_transformers.py,sha256=80qvAhandlAmt-mkiG7Ft6e5caCSDVRCFRfHppGvd5A,62216
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
+x_transformers-1.27.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.0.dist-info/METADATA,sha256=7kqhAXMJ-POUrT-6-QGBbgP4j7DCUR04bK8-ULmOYxQ,661
+x_transformers-1.27.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.0.dist-info/RECORD,,
x_transformers-1.26.5.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
-x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
-x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
-x_transformers/autoregressive_wrapper.py,sha256=47sc7HAMNBJUGZRtZX-cO-yML0YFcw4PF6E-7pp1E0A,9614
-x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
-x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=8n8R_huY0KwKDGTUlLLhleAqNR5M1YI_95KRmhrP_Eg,61740
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
-x_transformers-1.26.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.26.5.dist-info/METADATA,sha256=GcEy7CtmuqOpAapRxh7Et5kfPOBiV2EIa6GjN2U-eFM,661
-x_transformers-1.26.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.26.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.26.5.dist-info/RECORD,,
{x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/LICENSE
File without changes
{x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/WHEEL
File without changes
{x_transformers-1.26.5.dist-info → x_transformers-1.27.0.dist-info}/top_level.txt
File without changes