x-transformers 1.32.5__py3-none-any.whl → 1.32.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/multi_input.py +5 -20
- {x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/METADATA +1 -1
- {x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/RECORD +6 -6
- {x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/top_level.txt +0 -0
x_transformers/multi_input.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import torch
 from torch import nn, Tensor
 from torch.nn import Module, ModuleDict
@@ -15,7 +17,7 @@ from x_transformers.x_transformers import (
     LayerNorm,
     always,
     pad_at_dim,
-    is_empty
+    is_empty,
 )
 
 # helper functions
@@ -43,10 +45,7 @@ class MultiInputTransformerWrapper(Module):
         post_emb_norm = False,
         num_memory_tokens = None,
         memory_tokens_interspersed_every = None,
-        tie_embedding = False,
-        logits_dim = None,
         return_only_embed = False,
-        num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
@@ -85,23 +84,12 @@ class MultiInputTransformerWrapper(Module):
         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
 
-        assert num_output_heads > 0
-
         # output head, usually to logits of num_tokens
 
-        logits_dim = default(logits_dim, num_tokens)
-
-        self.has_multiple_heads = False
-
         if return_only_embed:
             self.to_logits = None
-        elif tie_embedding:
-            self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
-        elif num_output_heads > 1:
-            self.has_multiple_heads = True
-            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
         else:
-            self.to_logits = nn.Linear(dim, logits_dim, bias = False)
+            self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias = False) for name, logits_dim in num_tokens.items()})
 
         # memory tokens (like [cls]) from Memory Transformers paper
 
@@ -251,10 +239,7 @@ class MultiInputTransformerWrapper(Module):
         # projecting to logits
 
         if not return_embeddings:
-
-            logits = tuple(fn(x) for fn in self.to_logits)
-        else:
-            logits = self.to_logits(x)
+            logits = {name: fn(x) for name, fn in self.to_logits.items()}
 
         # different returns
 
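The net effect of the multi_input.py change: MultiInputTransformerWrapper drops the tie_embedding / logits_dim / num_output_heads options and instead builds one output head per named input, a ModuleDict keyed by the names in the num_tokens dict, and the forward pass returns a matching dict of logits. Below is a minimal usage sketch under assumptions not shown in this diff: that the constructor still takes max_seq_len and attn_layers like TransformerWrapper, and that the forward accepts a dict of token tensors keyed by the same names; the input names and vocabulary sizes are illustrative only.

import torch
from x_transformers import Decoder
from x_transformers.multi_input import MultiInputTransformerWrapper

# hypothetical input names and vocab sizes -- one nn.Linear head is created per entry
model = MultiInputTransformerWrapper(
    num_tokens = dict(
        pitch    = 256,
        duration = 128,
    ),
    max_seq_len = 1024,                          # assumed, mirroring TransformerWrapper
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# assumed input format: dict of (batch, seq) token id tensors keyed like num_tokens
seqs = dict(
    pitch    = torch.randint(0, 256, (2, 1024)),
    duration = torch.randint(0, 128, (2, 1024)),
)

logits = model(seqs)
# after 1.32.7, logits is a dict with one (batch, seq, vocab) tensor per input name,
# e.g. logits['pitch'] has shape (2, 1024, 256)

Before this release, to_logits was a single nn.Linear (or a ModuleList of identical heads when num_output_heads > 1), so every head shared one logits_dim; keying the heads by num_tokens lets each input stream project to its own vocabulary size.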
{x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/RECORD
CHANGED
@@ -3,13 +3,13 @@ x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,1134
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
-x_transformers/multi_input.py,sha256=
+x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
 x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.7.dist-info/METADATA,sha256=25J9CJ3OxsR_SZkvubPhyjSN-NmvU_yVVQHNMFzoKVg,661
+x_transformers-1.32.7.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.7.dist-info/RECORD,,
{x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/LICENSE
File without changes
{x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/WHEEL
File without changes
{x_transformers-1.32.5.dist-info → x_transformers-1.32.7.dist-info}/top_level.txt
File without changes