x-transformers 1.32.14__py3-none-any.whl → 1.32.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,9 @@ def default(val, d):
19
19
  def identity(t, *args, **kwargs):
20
20
  return t
21
21
 
22
+ def join(arr, delimiter = ', '):
23
+ return delimiter.join(arr)
24
+
22
25
  def cast_tuple(t, length = 1):
23
26
  return t if isinstance(t, tuple) else (t,) * length
24
27
 
@@ -79,11 +82,29 @@ def top_k(logits, frac_num_tokens = 0.1, k = None):
79
82
  # top_a
80
83
 
81
84
  def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
82
- probs = F.softmax(logits, dim = -1)
83
- max_probs = torch.amax(probs, dim = -1, keepdim = True)
85
+ probs = logits.softmax(dim = -1)
86
+ max_probs = probs.amax(dim = -1, keepdim = True)
84
87
  limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
85
88
  return torch.where(probs < limit, float('-inf'), logits)
86
89
 
90
+ # min_p
91
+ # https://arxiv.org/abs/2407.01082
92
+
93
+ def min_p(logits, min_p = 0.1):
94
+ probs = logits.softmax(dim = -1)
95
+ max_probs = probs.amax(dim = -1, keepdim = True)
96
+ limit = min_p * max_probs
97
+ return torch.where(probs < limit, float('-inf'), logits)
98
+
99
+ # filter logits functions dict[str -> Callable]
100
+
101
+ FILTER_LOGITS_FN = dict(
102
+ top_p = top_p,
103
+ top_k = top_k,
104
+ top_a = top_a,
105
+ min_p = min_p
106
+ )
107
+
87
108
  # contrastive decoding function
88
109
 
89
110
  def contrastive_decode_fn(
@@ -136,7 +157,7 @@ class AutoregressiveWrapper(Module):
136
157
  eos_token = None,
137
158
  temperature = 1.,
138
159
  prompt_lens: Tensor | None = None,
139
- filter_logits_fn: Callable = top_k,
160
+ filter_logits_fn: str | Callable = top_k,
140
161
  restrict_to_max_seq_len = True,
141
162
  amateur_model: Module | Tuple[Module] | None = None,
142
163
  filter_kwargs: dict = dict(),
@@ -153,6 +174,13 @@ class AutoregressiveWrapper(Module):
153
174
 
154
175
  b, t = prompts.shape
155
176
 
177
+ # handle filter logits fn given as string
178
+
179
+ if isinstance(filter_logits_fn, str):
180
+ assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
181
+
182
+ filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
183
+
156
184
  # handle variable lengthed prompts (prefixes)
157
185
 
158
186
  seq_start_pos = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.32.14
3
+ Version: 1.32.15
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,6 +1,6 @@
1
1
  x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
2
2
  x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
3
- x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
3
+ x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
4
4
  x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
8
8
  x_transformers/x_transformers.py,sha256=pyRQ6lb1Sx1CbjOH882tAv9UhAzsLwIeXDPBOsiRipg,78669
9
9
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
10
10
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
11
- x_transformers-1.32.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
- x_transformers-1.32.14.dist-info/METADATA,sha256=ChbSwpAxqxzvSjyOlJb9yO1GnoNZk2c0ioq5F-NuHI0,662
13
- x_transformers-1.32.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
14
- x_transformers-1.32.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
- x_transformers-1.32.14.dist-info/RECORD,,
11
+ x_transformers-1.32.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
+ x_transformers-1.32.15.dist-info/METADATA,sha256=cXOxhpryz1YtL2rhp8siArEI-K3t9-NOyh-4z1GfjDM,662
13
+ x_transformers-1.32.15.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
14
+ x_transformers-1.32.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
+ x_transformers-1.32.15.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.1.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5