x-transformers 2.1.9__py3-none-any.whl → 2.1.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- x_transformers/belief_state_wrapper.py +99 -1
- {x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/METADATA +1 -1
- {x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/RECORD +5 -5
- {x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -12,6 +12,11 @@ from torch.nn import Module, ModuleList
 from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
+from x_transformers.autoregressive_wrapper import (
+    eval_decorator,
+    min_p,
+)
+
 from x_transformers.x_transformers import (
     Decoder,
     TransformerWrapper
@@ -28,6 +33,15 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def eval_decorator(fn):
+    def inner(self, *args, **kwargs):
+        was_training = self.training
+        self.eval()
+        out = fn(self, *args, **kwargs)
+        self.train(was_training)
+        return out
+    return inner
+
 # wrappers
 
 class BeliefStateWrapper(Module):
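For reference, `eval_decorator` records the module's training flag, switches the module into eval mode for the duration of the decorated call, and then restores the previous state. A minimal sketch of its effect, using the copy importable from `x_transformers.autoregressive_wrapper` (which the import hunk above confirms exists); the `Toy` module below is invented purely for illustration:

import torch
from torch.nn import Module, Linear

# the same decorator is importable from autoregressive_wrapper, as the import hunk above shows
from x_transformers.autoregressive_wrapper import eval_decorator

class Toy(Module):
    # hypothetical module, used only for this illustration
    def __init__(self):
        super().__init__()
        self.proj = Linear(4, 4)

    @eval_decorator
    def sample(self, x):
        # self.training is False in here; dropout / batchnorm layers run in eval mode
        return self.proj(x)

toy = Toy()
toy.train()
toy.sample(torch.randn(1, 4))
assert toy.training  # the previous training state is restored after the call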
@@ -38,12 +52,14 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper,
+        backward_decoder: TransformerWrapper | None = None,
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
         backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
+        backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
+
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
         assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'
 
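Since `backward_decoder` now defaults to the forward decoder, the wrapper can be built from a single `TransformerWrapper`. A minimal construction sketch, assuming the standard `TransformerWrapper` / `Decoder` API; all sizes below are arbitrary illustration values, not defaults:

from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# one decoder shared by both directions (it must learn to "switch gears" on the suffix token)
decoder = TransformerWrapper(
    num_tokens = 256,       # illustration values only
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# backward_decoder is omitted, so it falls back to forward_decoder
model = BeliefStateWrapper(forward_decoder = decoder)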
@@ -85,6 +101,88 @@ class BeliefStateWrapper(Module):
 
         self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
 
+        # sampling
+
+        self.max_seq_len = self.forward_decoder.max_seq_len
+
+    @torch.no_grad()
+    @eval_decorator
+    def generate_with_suffix_token_only(
+        self,
+        prompts,
+        seq_len,
+        temperature = 1.25,
+        cache_kv = True,
+        suffix: Tensor | None = None, # the goal conditioning
+        filter_logits_fn = min_p,
+        filter_kwargs = dict(
+            min_p = 0.1
+        ),
+        **kwargs
+    ):
+        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+
+        batch, orig_seq_len = prompts.shape
+
+        out = prompts
+
+        # kv caches
+
+        cache = None
+
+        # get the encoded suffix token once
+
+        if not exists(suffix):
+            suffix = out[:, 0:0]
+
+        if suffix.ndim == 1:
+            suffix = repeat(suffix, 'n -> b n', b = batch)
+
+        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+
+        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+
+        suffix_embed = self.backward_decoder(
+            suffix,
+            prepend_embeds = suffix_sos_tokens,
+            return_embeddings = True
+        )
+
+        # sampling up to seq_len
+
+        for _ in range(seq_len):
+
+            embeds, new_cache = self.forward_decoder(
+                out,
+                return_intermediates = True,
+                return_embeddings = True,
+                cache = cache,
+                **kwargs
+            )
+
+            last_embeds = embeds[:, -1:]
+            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if cache_kv and self.forward_decoder.can_cache_kv:
+                cache = new_cache
+
+            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = logits[:, -1]
+
+            if greedy:
+                sample = logits.argmax(dim = -1, keepdim = True)
+            else:
+                filtered_logits = filter_logits_fn(logits, **filter_kwargs)
+                probs = F.softmax(filtered_logits / temperature, dim = -1)
+                sample = torch.multinomial(probs, 1)
+
+            # concat sample
+
+            out = torch.cat((out, sample), dim=-1)
+
+        return out[:, orig_seq_len:]
+
     def forward(
         self,
         seq,
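A hedged sketch of calling the new sampling method, continuing from the `model` constructed above. It samples `seq_len` tokens past the prompt and returns only the newly generated part; the prompt shape and token values are illustrative, and the optional `suffix` argument (the goal conditioning) is left at its default here:

import torch

prompts = torch.randint(0, 256, (2, 8))   # (batch, prompt length), illustrative only

generated = model.generate_with_suffix_token_only(
    prompts,
    seq_len = 64,          # tokens to sample beyond the prompt
    temperature = 1.25,    # temperature == 0. switches to greedy argmax
    cache_kv = True
)

print(generated.shape)  # (2, 64): the prompt is stripped from the returned sequence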
{x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=mpfTNZb8gadbtlpG2TxyfIMWkMVM4vigFDqCJ_mjxSU,8711
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.11.dist-info/METADATA,sha256=miwxNJVS0ZNJlw3qeJaGaxunTRxLgmz9WITWi_jnXcc,87571
+x_transformers-2.1.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.11.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.11.dist-info/RECORD,,
{x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/WHEEL: file without changes
{x_transformers-2.1.9.dist-info → x_transformers-2.1.11.dist-info}/licenses/LICENSE: file without changes