x-transformers 2.1.10__py3-none-any.whl → 2.1.12__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
x_transformers/belief_state_wrapper.py
@@ -12,6 +12,11 @@ from torch.nn import Module, ModuleList
  from torch import nn, cat, stack, tensor, arange, cartesian_prod
  import torch.nn.functional as F

+ from x_transformers.autoregressive_wrapper import (
+     eval_decorator,
+     min_p,
+ )
+
  from x_transformers.x_transformers import (
      Decoder,
      TransformerWrapper
@@ -28,6 +33,15 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def eval_decorator(fn):
+     def inner(self, *args, **kwargs):
+         was_training = self.training
+         self.eval()
+         out = fn(self, *args, **kwargs)
+         self.train(was_training)
+         return out
+     return inner
+
  # wrappers

  class BeliefStateWrapper(Module):
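
`eval_decorator`, defined above (and also imported from `autoregressive_wrapper` in the earlier hunk), wraps a method so the module runs in eval mode for the duration of the call and is then returned to whichever train/eval state it was in before. A minimal sketch of that behavior, using a made-up `Toy` module purely for illustration:

    import torch
    from torch import nn
    from x_transformers.autoregressive_wrapper import eval_decorator

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.drop = nn.Dropout(0.5)

        @eval_decorator
        def sample(self, x):
            # dropout is inactive here because the decorator has already called self.eval()
            return self.drop(x)

    model = Toy().train()
    out = model.sample(torch.ones(3))   # identity, since dropout is a no-op in eval mode
    assert model.training               # the original training mode is restored afterwards

Restoring `was_training` rather than unconditionally calling `.train()` keeps the decorator safe on modules that were already in eval mode.
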
@@ -87,6 +101,85 @@ class BeliefStateWrapper(Module):

          self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))

+         # sampling
+
+         self.max_seq_len = self.forward_decoder.max_seq_len
+
+     @torch.no_grad()
+     @eval_decorator
+     def generate_with_suffix_token_only(
+         self,
+         prompts,
+         seq_len,
+         temperature = 1.25,
+         cache_kv = True,
+         suffix: Tensor | None = None, # the goal conditioning
+         filter_logits_fn = min_p,
+         filter_kwargs = dict(
+             min_p = 0.1
+         ),
+         **kwargs
+     ):
+         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+
+         batch, orig_seq_len = prompts.shape
+
+         out = prompts
+
+         # kv caches
+
+         cache = None
+
+         # get the encoded suffix token once
+
+         if exists(suffix) and suffix.ndim == 1:
+             suffix = repeat(suffix, 'n -> b n', b = batch)
+
+         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+
+         suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+
+         suffix_embed = self.backward_decoder(
+             suffix,
+             prepend_embeds = suffix_sos_tokens,
+             return_embeddings = True
+         )
+
+         # sampling up to seq_len
+
+         for _ in range(seq_len):
+
+             embeds, new_cache = self.forward_decoder(
+                 out,
+                 return_intermediates = True,
+                 return_embeddings = True,
+                 cache = cache,
+                 **kwargs
+             )
+
+             last_embeds = embeds[:, -1:]
+             embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+             if cache_kv and self.forward_decoder.can_cache_kv:
+                 cache = new_cache
+
+             logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+
+             logits = logits[:, -1]
+
+             if greedy:
+                 sample = logits.argmax(dim = -1, keepdim = True)
+             else:
+                 filtered_logits = filter_logits_fn(logits, **filter_kwargs)
+                 probs = F.softmax(filtered_logits / temperature, dim = -1)
+                 sample = torch.multinomial(probs, 1)
+
+             # concat sample
+
+             out = torch.cat((out, sample), dim=-1)
+
+         return out[:, orig_seq_len:]
+
      def forward(
          self,
          seq,
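
The new `generate_with_suffix_token_only` method encodes the suffix (goal) tokens once with the backward decoder, then samples forward from the prompt, conditioning each step's logits on that suffix embedding, and returns only the newly sampled tokens. A hedged usage sketch follows; the `BeliefStateWrapper` constructor arguments are not shown in this diff, so the `forward_decoder =` keyword is an assumption, and all sizes and token ids are made up:

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    # assumed constructor keyword -- the wrapper exposes .forward_decoder / .backward_decoder internally
    wrapper = BeliefStateWrapper(forward_decoder = model)

    prompts = torch.randint(0, 256, (1, 16))   # prompt token ids
    suffix  = torch.randint(0, 256, (8,))      # goal tokens; a 1-D suffix is repeated across the batch

    generated = wrapper.generate_with_suffix_token_only(
        prompts,
        seq_len = 64,       # number of new tokens to sample
        suffix = suffix
    )
    # `generated` holds only the newly sampled tokens, shape (batch, seq_len)

Per the `greedy` branch in the hunk, passing `temperature = 0.` switches from filtered multinomial sampling to plain argmax decoding.
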
x_transformers/x_transformers.py
@@ -2898,6 +2898,15 @@ class TransformerWrapper(Module):
          to_logits_kwargs = dict(),
          **kwargs,
      ):
+
+         # if sequence is None, auto create an empty one if `prepend_embeds` was supplied
+
+         if not exists(x):
+             assert exists(prepend_embeds)
+             x = prepend_embeds.new_empty((prepend_embeds.shape[0], 0), dtype = torch.long)
+
+         # shapes and variables
+
          b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

          return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
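
This change lets `TransformerWrapper.forward` run without a token sequence when `prepend_embeds` is supplied: an empty `(batch, 0)` long tensor is created internally, so the forward pass operates on the prepended embeddings alone (the belief-state sampler above can hit this path when no suffix tokens are passed). A minimal sketch, with made-up dimensions:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
    )

    prefix = torch.randn(2, 4, 512)   # (batch, prefix length, model dim) embeddings

    # no token ids passed -- the wrapper now builds an empty (2, 0) long tensor internally
    embeds = model(None, prepend_embeds = prefix, return_embeddings = True)
    print(embeds.shape)               # expected: torch.Size([2, 4, 512])

Without the change, passing `None` would fail at the shape unpacking on the first line of the function body.
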
x_transformers-2.1.10.dist-info/METADATA → x_transformers-2.1.12.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.1.10
+ Version: 2.1.12
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.10.dist-info/RECORD → x_transformers-2.1.12.dist-info/RECORD
@@ -1,16 +1,16 @@
  x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
- x_transformers/belief_state_wrapper.py,sha256=_jGwSjaNOP2ztFU6OR1B623PGyN-BYopYKQbV8b05do,6190
+ x_transformers/belief_state_wrapper.py,sha256=aMmekjRNHem-4MKXTK8z_u0497EThUhvKLISwaKbqQw,8665
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
+ x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-2.1.10.dist-info/METADATA,sha256=g71r5_pP-i2t9v5h5lkDpKVRwVE2SDTWskZAopLw0X8,87571
- x_transformers-2.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.1.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.1.10.dist-info/RECORD,,
+ x_transformers-2.1.12.dist-info/METADATA,sha256=JXQYWgfNcv43jVmFY_FAiIhD5EYvX88BA4zJXeRMxa0,87571
+ x_transformers-2.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.1.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.1.12.dist-info/RECORD,,