x-transformers 2.1.9__py3-none-any.whl → 2.1.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.
@@ -12,6 +12,11 @@ from torch.nn import Module, ModuleList
 from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F

+from x_transformers.autoregressive_wrapper import (
+    eval_decorator,
+    min_p,
+)
+
 from x_transformers.x_transformers import (
     Decoder,
     TransformerWrapper
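For reference, min_p filtering keeps only the tokens whose probability is within a fraction of the most likely token's probability. A minimal sketch of that idea (not necessarily the exact min_p implementation in x_transformers/autoregressive_wrapper.py, which is what the hunk above imports):

import torch

def min_p_filter(logits, min_p = 0.1):
    # keep tokens whose probability is at least `min_p` times the top token's
    # probability, masking everything else to -inf so softmax ignores it
    probs = logits.softmax(dim = -1)
    limit = min_p * probs.amax(dim = -1, keepdim = True)
    return torch.where(probs >= limit, logits, torch.full_like(logits, float('-inf')))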
@@ -28,6 +33,15 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def eval_decorator(fn):
+    def inner(self, *args, **kwargs):
+        was_training = self.training
+        self.eval()
+        out = fn(self, *args, **kwargs)
+        self.train(was_training)
+        return out
+    return inner
+
 # wrappers

 class BeliefStateWrapper(Module):
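The decorator above simply runs the wrapped method with the module in eval mode and then restores whatever mode the module was in. A small sketch of the behavior, using a hypothetical ToyModule (only the import of eval_decorator is taken from the library):

import torch
from torch import nn
from x_transformers.autoregressive_wrapper import eval_decorator

class ToyModule(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @torch.no_grad()
    @eval_decorator
    def sample(self, x):
        # runs with self.training == False, regardless of the caller's mode
        return self.proj(x)

model = ToyModule()
model.train()
model.sample(torch.randn(2, 4))
assert model.training  # the previous mode is restored after the call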
@@ -38,12 +52,14 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper,
+        backward_decoder: TransformerWrapper | None = None,
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
         backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
+        backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
+
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
         assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'

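With backward_decoder now optional, the wrapper can be built from a single TransformerWrapper that serves both directions. A minimal sketch, assuming the usual TransformerWrapper / Decoder constructor arguments and that the wrapper's forward pass returns a training loss:

import torch
from x_transformers.x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# backward_decoder is omitted, so it defaults to the forward decoder
wrapper = BeliefStateWrapper(forward_decoder = decoder)

seq = torch.randint(0, 256, (2, 128))
loss = wrapper(seq)  # assumed to return the combined forward/backward loss
loss.backward()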
@@ -85,6 +101,88 @@ class BeliefStateWrapper(Module):

         self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))

+        # sampling
+
+        self.max_seq_len = self.forward_decoder.max_seq_len
+
+    @torch.no_grad()
+    @eval_decorator
+    def generate_with_suffix_token_only(
+        self,
+        prompts,
+        seq_len,
+        temperature = 1.25,
+        cache_kv = True,
+        suffix: Tensor | None = None, # the goal conditioning
+        filter_logits_fn = min_p,
+        filter_kwargs = dict(
+            min_p = 0.1
+        ),
+        **kwargs
+    ):
+        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+
+        batch, orig_seq_len = prompts.shape
+
+        out = prompts
+
+        # kv caches
+
+        cache = None
+
+        # get the encoded suffix token once
+
+        if not exists(suffix):
+            suffix = out[:, 0:0]
+
+        if suffix.ndim == 1:
+            suffix = repeat(suffix, 'n -> b n', b = batch)
+
+        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+
+        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+
+        suffix_embed = self.backward_decoder(
+            suffix,
+            prepend_embeds = suffix_sos_tokens,
+            return_embeddings = True
+        )
+
+        # sampling up to seq_len
+
+        for _ in range(seq_len):
+
+            embeds, new_cache = self.forward_decoder(
+                out,
+                return_intermediates = True,
+                return_embeddings = True,
+                cache = cache,
+                **kwargs
+            )
+
+            last_embeds = embeds[:, -1:]
+            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if cache_kv and self.forward_decoder.can_cache_kv:
+                cache = new_cache
+
+            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = logits[:, -1]
+
+            if greedy:
+                sample = logits.argmax(dim = -1, keepdim = True)
+            else:
+                filtered_logits = filter_logits_fn(logits, **filter_kwargs)
+                probs = F.softmax(filtered_logits / temperature, dim = -1)
+                sample = torch.multinomial(probs, 1)
+
+            # concat sample
+
+            out = torch.cat((out, sample), dim=-1)
+
+        return out[:, orig_seq_len:]
+
     def forward(
         self,
         seq,
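Putting the new method together: generation is prompted as usual and returns only the newly sampled tokens; when suffix is omitted, the goal conditioning reduces to the learned suffix start token. A hedged usage sketch, continuing the wrapper instance from the previous example:

prompts = torch.randint(0, 256, (2, 16))   # (batch, prompt length)

generated = wrapper.generate_with_suffix_token_only(
    prompts,
    seq_len = 64,         # number of new tokens to sample
    temperature = 1.25,   # temperature == 0. switches to greedy decoding
)

assert generated.shape == (2, 64)  # prompt tokens are stripped from the output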
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.9
+Version: 2.1.11
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=bA-H9BqfyVRI5Q7GIcGbzdLjmon0CKKGHb08-BnpJOs,5990
+x_transformers/belief_state_wrapper.py,sha256=mpfTNZb8gadbtlpG2TxyfIMWkMVM4vigFDqCJ_mjxSU,8711
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.9.dist-info/METADATA,sha256=YXRYQPw90873on2JyVr56N8Tz4ua2ibPawTUDvVE35g,87570
-x_transformers-2.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.9.dist-info/RECORD,,
+x_transformers-2.1.11.dist-info/METADATA,sha256=miwxNJVS0ZNJlw3qeJaGaxunTRxLgmz9WITWi_jnXcc,87571
+x_transformers-2.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.11.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.11.dist-info/RECORD,,