x-transformers 1.42.7__py3-none-any.whl → 1.42.8__py3-none-any.whl

x_transformers/x_transformers.py
@@ -238,6 +238,13 @@ class TokenEmbedding(Module):
         token_emb = self.emb(x.long())
         return l2norm(token_emb) if self.l2norm_embed else token_emb
 
+    def init_(self):
+        if self.l2norm_embed:
+            nn.init.normal_(self.emb.weight, std=1e-5)
+            return
+        nn.init.kaiming_normal_(self.emb.weight)
+
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -2261,7 +2268,8 @@ class TransformerWrapper(Module):
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
         mixture_of_softmax_k = 4,
-        sigsoftmax_logits = False
+        sigsoftmax_logits = False,
+        to_logits: Module | None = None,
     ):
         super().__init__()
 
@@ -2363,11 +2371,12 @@ class TransformerWrapper(Module):
         if return_only_embed:
             self.to_logits = None
         elif tie_embedding:
+            assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
         else:
-            self.to_logits = LinearNoBias(dim, logits_dim)
+            self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
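For reference, a minimal sketch of what the new `to_logits` argument enables, assuming the usual `TransformerWrapper` / `Decoder` construction; the `custom_head` module and its sizes are illustrative, not taken from the diff:

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

num_tokens, dim = 256, 512

# hypothetical replacement for the default LinearNoBias output projection
custom_head = nn.Sequential(
    nn.Linear(dim, dim * 2),
    nn.GELU(),
    nn.Linear(dim * 2, num_tokens, bias = False)
)

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = dim, depth = 2, heads = 8),
    to_logits = custom_head  # new in 1.42.8; falls back to LinearNoBias when omitted
)

logits = model(torch.randint(0, num_tokens, (1, 128)))  # (1, 128, num_tokens)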
@@ -2388,13 +2397,12 @@ class TransformerWrapper(Module):
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
+        if hasattr(self.token_emb, 'init_'):
+            self.token_emb.init_()
+
         if self.l2norm_embed:
-            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
             if not isinstance(self.pos_emb, always):
                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
-            return
-
-        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
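A sketch of the delegation this enables: any token embedding module that exposes its own `init_` is now initialized through it. The `OrthogonalTokenEmbedding` subclass below is hypothetical, and it assumes a `TokenEmbedding(dim, num_tokens, ...)` constructor signature:

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder
from x_transformers.x_transformers import TokenEmbedding

class OrthogonalTokenEmbedding(TokenEmbedding):
    # hypothetical subclass: swaps in an orthogonal init for the embedding weights
    def init_(self):
        nn.init.orthogonal_(self.emb.weight)

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8),
    token_emb = OrthogonalTokenEmbedding(512, 256)  # picked up by the hasattr check above
)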
@@ -2417,7 +2425,9 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         seq_start_pos = None,
         cache: LayerIntermediates | None = None,
-        **kwargs
+        token_emb_kwargs = dict(),
+        to_logits_kwargs = dict(),
+        **kwargs,
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
@@ -2428,7 +2438,7 @@ class TransformerWrapper(Module):
 
         external_pos_emb = exists(pos) and pos.dtype != torch.long
         pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
-        x = self.token_emb(x) + pos_emb
+        x = self.token_emb(x, **token_emb_kwargs) + pos_emb
 
         # add additional embeddings
 
@@ -2583,9 +2593,9 @@ class TransformerWrapper(Module):
 
         if not return_embeddings:
             if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
+                logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
             else:
-                logits = self.to_logits(x)
+                logits = self.to_logits(x, **to_logits_kwargs)
 
             # maybe sig softmax
 
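Taken together with the forward-signature change above, the new `token_emb_kwargs` and `to_logits_kwargs` are simply passed through to `self.token_emb(...)` and `self.to_logits(...)`, so they only matter when those modules accept extra keyword arguments. A hypothetical sketch; the `TemperatureHead` module and its `temperature` kwarg are illustrative only:

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class TemperatureHead(nn.Module):
    # hypothetical output head whose forward takes an extra keyword argument
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.proj = nn.Linear(dim, num_tokens, bias = False)

    def forward(self, x, temperature = 1.):
        return self.proj(x) / temperature

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8),
    to_logits = TemperatureHead(512, 256)
)

x = torch.randint(0, 256, (1, 128))
logits = model(x, to_logits_kwargs = dict(temperature = 0.7))  # forwarded to TemperatureHead.forward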
{x_transformers-1.42.7.dist-info → x_transformers-1.42.8.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.7
+Version: 1.42.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.7.dist-info → x_transformers-1.42.8.dist-info}/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=6jXSMHViCU64gLMbxRJ6C8bgcLrPFbT-m-fhtusqq3g,93117
+x_transformers/x_transformers.py,sha256=275B_yDHePxUvlLcMNgnCUmZ1qZEkwBrpk6IA8n-pnY,93550
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.7.dist-info/METADATA,sha256=tM7s2gIMFH8hy_YZY84BhZ-yUoH6PTyjusK0dMOpTN8,689
-x_transformers-1.42.7.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.7.dist-info/RECORD,,
+x_transformers-1.42.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.8.dist-info/METADATA,sha256=1d2BVA6iHKpT4UzbYxw16ijAFGJT-u29zTnYtV6Lp3w,689
+x_transformers-1.42.8.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.8.dist-info/RECORD,,