x-transformers 1.26.5__py3-none-any.whl → 1.27.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
--- a/x_transformers/autoregressive_wrapper.py
+++ b/x_transformers/autoregressive_wrapper.py
@@ -190,7 +190,7 @@ class AutoregressiveWrapper(Module):
 if restrict_to_max_seq_len:
 max_len_exceeded = out.shape[-1] > max_seq_len
 
- assert not (cache_kv and max_len_exceeded and self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'
+ assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'
 
 x = out[:, -max_seq_len:]
 
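The corrected assert only trips when key/value caching is requested, the generated sequence has grown past max_seq_len, and the network reports that it cannot cache outside that length; the previous condition negated the wrong term. A minimal sketch of the guard it enforces; the standalone function name and arguments here are hypothetical, only can_cache_kv_outside_max_seq_len mirrors the library flag:

def kv_cache_ok(cache_kv, out_len, max_seq_len, can_cache_kv_outside_max_seq_len):
    # caching past max_seq_len is only valid when the network supports positions
    # beyond max_seq_len, e.g. rotary rather than absolute positional embeddings
    max_len_exceeded = out_len > max_seq_len
    return not (cache_kv and max_len_exceeded and not can_cache_kv_outside_max_seq_len)

assert kv_cache_ok(cache_kv = True, out_len = 2048, max_seq_len = 1024, can_cache_kv_outside_max_seq_len = True)
assert not kv_cache_ok(cache_kv = True, out_len = 2048, max_seq_len = 1024, can_cache_kv_outside_max_seq_len = False)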
--- a/x_transformers/continuous.py
+++ b/x_transformers/continuous.py
@@ -8,6 +8,7 @@ from x_transformers.x_transformers import (
 AttentionLayers,
 ScaledSinusoidalEmbedding,
 AbsolutePositionalEmbedding,
+ LayerNorm,
 always,
 pad_at_dim
 )
@@ -54,7 +55,7 @@ class ContinuousTransformerWrapper(nn.Module):
 else:
 self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
 
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+ self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity()
 self.emb_dropout = nn.Dropout(emb_dropout)
 
 # memory tokens
@@ -71,8 +72,8 @@ class ContinuousTransformerWrapper(nn.Module):
 
 # project in and out
 
- self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
- self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
+ self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
+ self.project_out = nn.Linear(dim, dim_out, bias = False) if exists(dim_out) else nn.Identity()
 
 def forward(
 self,
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -304,7 +304,7 @@ class DynamicPositionBias(nn.Module):
 
 self.mlp.append(Sequential(
 nn.Linear(1, dim),
- nn.LayerNorm(dim) if norm else None,
+ LayerNorm(dim) if norm else None,
 nn.SiLU()
 ))
 
@@ -498,6 +498,19 @@ class ScaleNorm(nn.Module):
 norm = torch.norm(x, dim = -1, keepdim = True)
 return x / norm.clamp(min = self.eps) * self.g
 
+ class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ """
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
+ latest pytorch actually has a way to turn this off in nn.LayerNorm
+ """
+ super().__init__()
+ self.gamma = nn.Parameter(torch.ones(dim))
+ self.register_buffer("beta", torch.zeros(dim))
+
+ def forward(self, x):
+ return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 class RMSNorm(nn.Module):
 def __init__(self, dim):
 super().__init__()
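The docstring of the new LayerNorm points out that recent PyTorch can drop the bias directly in nn.LayerNorm. A rough equivalence sketch, assuming PyTorch >= 2.1 (where nn.LayerNorm accepts bias = False); the shapes and values are illustrative only:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 512
x = torch.randn(2, 16, dim)

# bias-less layer norm as in the class above: learnable gamma, fixed zero beta
gamma = torch.ones(dim)
beta = torch.zeros(dim)
manual = F.layer_norm(x, x.shape[-1:], gamma, beta)

# native equivalent on newer PyTorch: elementwise affine weight, no bias term
native = nn.LayerNorm(dim, bias = False)
assert torch.allclose(native(x), manual, atol = 1e-6)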
@@ -634,7 +647,7 @@ class FeedForward(nn.Module):
 
 self.ff = Sequential(
 project_in,
- nn.LayerNorm(inner_dim) if post_act_ln else None,
+ LayerNorm(inner_dim) if post_act_ln else None,
 nn.Dropout(dropout),
 nn.Linear(inner_dim, dim_out, bias = not no_bias)
 )
@@ -1083,7 +1096,7 @@ class AttentionLayers(nn.Module):
 elif use_simple_rmsnorm:
 norm_class = SimpleRMSNorm
 else:
- norm_class = nn.LayerNorm
+ norm_class = LayerNorm
 
 norm_fn = partial(norm_class, dim)
 
@@ -1415,12 +1428,12 @@ class ViTransformerWrapper(nn.Module):
 self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
 
 self.patch_to_embedding = nn.Sequential(
- nn.LayerNorm(patch_dim),
+ LayerNorm(patch_dim),
 nn.Linear(patch_dim, dim),
- nn.LayerNorm(dim)
+ LayerNorm(dim)
 )
 
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
+ LayerNorm(dim) if post_emb_norm else nn.Identity()
 self.dropout = nn.Dropout(emb_dropout)
 
 self.attn_layers = attn_layers
@@ -1515,7 +1528,7 @@ class TransformerWrapper(nn.Module):
 
 self.emb_frac_gradient = emb_frac_gradient
 
- self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
+ self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
 self.emb_dropout = nn.Dropout(emb_dropout)
 
 self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1524,7 +1537,7 @@ class TransformerWrapper(nn.Module):
 self.init_()
 
 logits_dim = default(logits_dim, num_tokens)
- self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
+ self.to_logits = nn.Linear(dim, logits_dim, bias = False) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
 
 # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -1538,7 +1551,7 @@ class TransformerWrapper(nn.Module):
 # whether can do cached kv decoding
 
 self.can_cache_kv = self.num_memory_tokens == 0
- self.can_cache_kv_outside_max_seq_len = not no_abs_pos_emb
+ self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
 def init_(self):
 if self.l2norm_embed:
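The flipped can_cache_kv_outside_max_seq_len flag pairs with the assert change in AutoregressiveWrapper above: cached decoding beyond max_seq_len is permitted exactly when no absolute positional embedding is in use. A hedged usage sketch; the constructor arguments follow the x-transformers README and the exact values are illustrative:

from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    use_abs_pos_emb = False,      # no absolute positional embedding ...
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True     # ... rely on rotary embeddings instead
    )
)

wrapped = AutoregressiveWrapper(model)
# with no absolute positional embedding, can_cache_kv_outside_max_seq_len is True,
# so generation with cache_kv = True may run past max_seq_len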
--- a/x_transformers-1.26.5.dist-info/METADATA
+++ b/x_transformers-1.27.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
- Version: 1.26.5
+ Version: 1.27.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- /dev/null
+++ b/x_transformers-1.27.0.dist-info/RECORD
@@ -0,0 +1,13 @@
+ x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
+ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
+ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
+ x_transformers/continuous.py,sha256=Ra5IClCl9G7SAiM6L9w6iY4cCznH0dSGljC9AC_bNyw,6066
+ x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+ x_transformers/x_transformers.py,sha256=80qvAhandlAmt-mkiG7Ft6e5caCSDVRCFRfHppGvd5A,62216
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
+ x_transformers-1.27.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.27.0.dist-info/METADATA,sha256=7kqhAXMJ-POUrT-6-QGBbgP4j7DCUR04bK8-ULmOYxQ,661
+ x_transformers-1.27.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ x_transformers-1.27.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.27.0.dist-info/RECORD,,
--- a/x_transformers-1.26.5.dist-info/RECORD
+++ /dev/null
@@ -1,13 +0,0 @@
- x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
- x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
- x_transformers/autoregressive_wrapper.py,sha256=47sc7HAMNBJUGZRtZX-cO-yML0YFcw4PF6E-7pp1E0A,9614
- x_transformers/continuous.py,sha256=ixfgi2_zpGN03SX_STXFkNYEOAkgwVIxuS53QgDCx-g,6026
- x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
- x_transformers/x_transformers.py,sha256=8n8R_huY0KwKDGTUlLLhleAqNR5M1YI_95KRmhrP_Eg,61740
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
- x_transformers/xval.py,sha256=lS9W_E_RskPQAqVZkPiUzbByoW1Ajsw_phsikA3JDAg,8139
- x_transformers-1.26.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.26.5.dist-info/METADATA,sha256=GcEy7CtmuqOpAapRxh7Et5kfPOBiV2EIa6GjN2U-eFM,661
- x_transformers-1.26.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- x_transformers-1.26.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.26.5.dist-info/RECORD,,