x-transformers 1.32.12__py3-none-any.whl → 1.32.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +31 -3
- x_transformers/x_transformers.py +7 -2
- {x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/METADATA +1 -1
- {x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/RECORD +7 -7
- {x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/WHEEL +1 -1
- {x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -19,6 +19,9 @@ def default(val, d):
 def identity(t, *args, **kwargs):
     return t
 
+def join(arr, delimiter = ', '):
+    return delimiter.join(arr)
+
 def cast_tuple(t, length = 1):
     return t if isinstance(t, tuple) else (t,) * length
 
@@ -79,11 +82,29 @@ def top_k(logits, frac_num_tokens = 0.1, k = None):
 # top_a
 
 def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
-    probs =
-    max_probs =
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
     limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
     return torch.where(probs < limit, float('-inf'), logits)
 
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
+# filter logits functions dict[str -> Callable]
+
+FILTER_LOGITS_FN = dict(
+    top_p = top_p,
+    top_k = top_k,
+    top_a = top_a,
+    min_p = min_p
+)
+
 # contrastive decoding function
 
 def contrastive_decode_fn(
@@ -136,7 +157,7 @@ class AutoregressiveWrapper(Module):
         eos_token = None,
         temperature = 1.,
         prompt_lens: Tensor | None = None,
-        filter_logits_fn: Callable = top_k,
+        filter_logits_fn: str | Callable = top_k,
         restrict_to_max_seq_len = True,
         amateur_model: Module | Tuple[Module] | None = None,
         filter_kwargs: dict = dict(),
@@ -153,6 +174,13 @@ class AutoregressiveWrapper(Module):
 
         b, t = prompts.shape
 
+        # handle filter logits fn given as string
+
+        if isinstance(filter_logits_fn, str):
+            assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
+
+            filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
+
         # handle variable lengthed prompts (prefixes)
 
         seq_start_pos = None
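Taken together, the autoregressive_wrapper.py changes add a `min_p` filter (following https://arxiv.org/abs/2407.01082), register the filter functions in `FILTER_LOGITS_FN`, and let `filter_logits_fn` be passed to generation as a string that is resolved through that dict. Below is a minimal usage sketch; the `TransformerWrapper`/`Decoder` construction follows the library's standard README pattern, and the model sizes and sampling values are illustrative rather than taken from this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# standard x-transformers setup (sizes are illustrative)
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = AutoregressiveWrapper(model)

prompts = torch.randint(0, 20000, (1, 128))

# with this release, filter_logits_fn may be a string naming one of
# 'top_p', 'top_k', 'top_a' or 'min_p'; it is looked up in FILTER_LOGITS_FN
sampled = wrapper.generate(
    prompts,
    256,
    temperature = 1.,
    filter_logits_fn = 'min_p',
    filter_kwargs = dict(min_p = 0.1)
)
```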
x_transformers/x_transformers.py
CHANGED
@@ -1925,7 +1925,8 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
-        squeeze_out_last_dim = False
+        squeeze_out_last_dim = False,
+        token_emb: TokenEmbedding | None = None,
     ):
         super().__init__()
 
@@ -1939,7 +1940,11 @@ class TransformerWrapper(Module):
         self.shift_mem_down = shift_mem_down
 
         self.l2norm_embed = l2norm_embed
-
+
+        if not exists(token_emb):
+            token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
+
+        self.token_emb = token_emb
 
         no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
 
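The x_transformers.py change adds a `token_emb` keyword to `TransformerWrapper`, so a pre-built `TokenEmbedding` can be supplied; when it is omitted, the wrapper falls back to constructing `TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)` as before. A sketch of how the new keyword might be used, assuming `TokenEmbedding` is importable from `x_transformers.x_transformers` (where the class is defined); the other constructor arguments are the usual illustrative ones.

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.x_transformers import TokenEmbedding

dim = 512
num_tokens = 20000

# build the embedding up front, e.g. to customize it or share it elsewhere
custom_emb = TokenEmbedding(dim, num_tokens, l2norm_embed = False)

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = 1024,
    token_emb = custom_emb,  # keyword added in this diff
    attn_layers = Decoder(dim = dim, depth = 6, heads = 8)
)

tokens = torch.randint(0, num_tokens, (1, 256))
logits = model(tokens)  # (1, 256, num_tokens)
```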
{x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=pyRQ6lb1Sx1CbjOH882tAv9UhAzsLwIeXDPBOsiRipg,78669
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.15.dist-info/METADATA,sha256=cXOxhpryz1YtL2rhp8siArEI-K3t9-NOyh-4z1GfjDM,662
+x_transformers-1.32.15.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
+x_transformers-1.32.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.15.dist-info/RECORD,,
{x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/LICENSE
File without changes
{x_transformers-1.32.12.dist-info → x_transformers-1.32.15.dist-info}/top_level.txt
File without changes