x-transformers 1.32.6__py3-none-any.whl → 1.32.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/multi_input.py +3 -20
- {x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/METADATA +1 -1
- {x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/RECORD +6 -6
- {x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/top_level.txt +0 -0
x_transformers/multi_input.py
CHANGED
@@ -17,7 +17,7 @@ from x_transformers.x_transformers import (
     LayerNorm,
     always,
     pad_at_dim,
-    is_empty
+    is_empty,
 )
 
 # helper functions
@@ -45,10 +45,7 @@ class MultiInputTransformerWrapper(Module):
         post_emb_norm = False,
         num_memory_tokens = None,
         memory_tokens_interspersed_every = None,
-        tie_embedding = False,
-        logits_dim = None,
         return_only_embed = False,
-        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
@@ -87,23 +84,12 @@ class MultiInputTransformerWrapper(Module):
         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
 
-        assert num_output_heads > 0
-
         # output head, usually to logits of num_tokens
 
-        logits_dim = default(logits_dim, num_tokens)
-
-        self.has_multiple_heads = False
-
         if return_only_embed:
             self.to_logits = None
-        elif tie_embedding:
-            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
-        elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias = False) for name, logits_dim in num_tokens.items()})
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -253,10 +239,7 @@ class MultiInputTransformerWrapper(Module):
         # projecting to logits
 
         if not return_embeddings:
-            if self.has_multiple_heads:
-                logits = tuple(fn(x) for fn in self.to_logits)
-            else:
-                logits = self.to_logits(x)
+            logits = {name: fn(x) for name, fn in self.to_logits.items()}
 
         # different returns
 
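The net effect of the multi_input.py change: instead of a single shared projection (optionally weight-tied or duplicated across num_output_heads), the wrapper now builds one linear head per named input, sized to that input's vocabulary, and forward returns a dict of logits keyed by those names. Below is a minimal usage sketch of the new behavior, not taken from the diff itself; it assumes the constructor still accepts max_seq_len and attn_layers in the style of TransformerWrapper, and the input names and vocabulary sizes are purely illustrative.

import torch
from x_transformers import Decoder
from x_transformers.multi_input import MultiInputTransformerWrapper

model = MultiInputTransformerWrapper(
    num_tokens = dict(      # one vocabulary, and after 1.32.7 one output head, per input name
        pitch = 128,
        velocity = 32
    ),
    max_seq_len = 1024,     # assumed kwarg, mirroring TransformerWrapper
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = dict(
    pitch = torch.randint(0, 128, (1, 1024)),
    velocity = torch.randint(0, 32, (1, 1024))
)

logits = model(x)

# with the ModuleDict head, logits comes back keyed by input name:
#   logits['pitch'].shape    -> (1, 1024, 128)
#   logits['velocity'].shape -> (1, 1024, 32)

Callers that previously indexed a tuple of heads (num_output_heads > 1) or relied on tie_embedding / logits_dim now key into the returned dict instead.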
{x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/RECORD
CHANGED
@@ -3,13 +3,13 @@ x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,1134
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/multi_input.py,sha256=
+x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
 x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.7.dist-info/METADATA,sha256=25J9CJ3OxsR_SZkvubPhyjSN-NmvU_yVVQHNMFzoKVg,661
+x_transformers-1.32.7.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.7.dist-info/RECORD,,
{x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/LICENSE
File without changes
{x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/WHEEL
File without changes
{x_transformers-1.32.6.dist-info → x_transformers-1.32.7.dist-info}/top_level.txt
File without changes