x-transformers 1.32.6 → 1.32.7 (py3-none-any.whl)

--- a/x_transformers/multi_input.py
+++ b/x_transformers/multi_input.py
@@ -17,7 +17,7 @@ from x_transformers.x_transformers import (
     LayerNorm,
     always,
     pad_at_dim,
-    is_empty
+    is_empty,
 )
 
 # helper functions
@@ -45,10 +45,7 @@ class MultiInputTransformerWrapper(Module):
         post_emb_norm = False,
         num_memory_tokens = None,
         memory_tokens_interspersed_every = None,
-        tie_embedding = False,
-        logits_dim = None,
         return_only_embed = False,
-        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
@@ -87,23 +84,12 @@ class MultiInputTransformerWrapper(Module):
         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
 
-        assert num_output_heads > 0
-
         # output head, usually to logits of num_tokens
 
-        logits_dim = default(logits_dim, num_tokens)
-
-        self.has_multiple_heads = False
-
         if return_only_embed:
             self.to_logits = None
-        elif tie_embedding:
-            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
-        elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias = False) for name, logits_dim in num_tokens.items()})
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
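The new head construction implies that `num_tokens` is now a dict mapping each named input stream to its vocabulary size, with one independent linear head per name replacing the old single head (and the removed `tie_embedding` / `num_output_heads` variants). A minimal sketch of the pattern in isolation; the input names and sizes are hypothetical:

```python
import torch
from torch import nn

# per-input output heads, as in the hunk above: the dict values of
# `num_tokens` double as each head's logits dimension
num_tokens = dict(pitch = 128, velocity = 32)   # hypothetical input names / vocab sizes
dim = 512

to_logits = nn.ModuleDict({
    name: nn.Linear(dim, logits_dim, bias = False)
    for name, logits_dim in num_tokens.items()
})

x = torch.randn(2, 1024, dim)                   # embeddings out of the attention layers
logits = {name: fn(x) for name, fn in to_logits.items()}

assert logits['pitch'].shape == (2, 1024, 128)
assert logits['velocity'].shape == (2, 1024, 32)
```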
@@ -253,10 +239,7 @@ class MultiInputTransformerWrapper(Module):
         # projecting to logits
 
         if not return_embeddings:
-            if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
-            else:
-                logits = self.to_logits(x)
+            logits = {name: fn(x) for name, fn in self.to_logits.items()}
 
         # different returns
 
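Taken together, the two hunks change the forward return from a single tensor (or a tuple of tensors when multiple heads were configured) to a dict with one logits tensor per named input. A hedged end-to-end sketch, assuming the wrapper mirrors `TransformerWrapper`'s constructor (`num_tokens`, `max_seq_len`, `attn_layers`) and that `forward` accepts a dict of token tensors keyed like `num_tokens`; the names `pitch` and `velocity` are again hypothetical:

```python
import torch
from x_transformers import Decoder
from x_transformers.multi_input import MultiInputTransformerWrapper

# each named stream presumably gets its own embedding on the way in and,
# after this change, a dedicated linear head on the way out
model = MultiInputTransformerWrapper(
    num_tokens = dict(pitch = 128, velocity = 32),
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

tokens = dict(
    pitch    = torch.randint(0, 128, (2, 1024)),
    velocity = torch.randint(0, 32,  (2, 1024))
)

logits = model(tokens)
# logits['pitch']    -> (2, 1024, 128)
# logits['velocity'] -> (2, 1024, 32)
```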
--- a/x_transformers-1.32.6.dist-info/METADATA
+++ b/x_transformers-1.32.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.6
+Version: 1.32.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- a/x_transformers-1.32.6.dist-info/RECORD
+++ b/x_transformers-1.32.7.dist-info/RECORD
@@ -3,13 +3,13 @@ x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,1134
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/multi_input.py,sha256=QvYrueLPcfcm0gvoSZYCd7zVgUTi2i0fZkvXowCwx_s,9794
+x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
 x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.6.dist-info/METADATA,sha256=_Uhhkxnq0aIykqJxbQdQOpxcnYJcciV5Z9SwghDiTpQ,661
-x_transformers-1.32.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.6.dist-info/RECORD,,
+x_transformers-1.32.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.7.dist-info/METADATA,sha256=25J9CJ3OxsR_SZkvubPhyjSN-NmvU_yVVQHNMFzoKVg,661
+x_transformers-1.32.7.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.7.dist-info/RECORD,,