sdft-pytorch 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sdft_pytorch-0.0.1/sdft_pytorch.egg-info → sdft_pytorch-0.0.2}/PKG-INFO +1 -1
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/pyproject.toml +1 -1
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch/sdft_pytorch.py +32 -10
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2/sdft_pytorch.egg-info}/PKG-INFO +1 -1
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/tests/test_sdft.py +8 -1
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/LICENSE +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/README.md +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch/__init__.py +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch.egg-info/SOURCES.txt +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch.egg-info/dependency_links.txt +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch.egg-info/requires.txt +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/sdft_pytorch.egg-info/top_level.txt +0 -0
- {sdft_pytorch-0.0.1 → sdft_pytorch-0.0.2}/setup.cfg +0 -0
|
@@ -11,7 +11,11 @@ from torch import nn, cat, stack, is_tensor, tensor, Tensor
|
|
|
11
11
|
|
|
12
12
|
from einops import rearrange
|
|
13
13
|
|
|
14
|
-
from torch_einops_utils import
|
|
14
|
+
from torch_einops_utils import (
|
|
15
|
+
pad_sequence,
|
|
16
|
+
safe_cat,
|
|
17
|
+
masked_mean
|
|
18
|
+
)
|
|
15
19
|
|
|
16
20
|
from ema_pytorch import EMA
|
|
17
21
|
|
|
@@ -71,6 +75,7 @@ class SDFT(Module):
|
|
|
71
75
|
student_prompt_template = DEFAULT_STUDENT_PROMPT_TEMPLATE,
|
|
72
76
|
teacher_update_rate = 0.01,
|
|
73
77
|
teacher_prompt_template = DEFAULT_TEACHER_PROMPT_TEMPLATE,
|
|
78
|
+
eos_id = None, # if set, will mask out any losses after the first eos token id detected in a given sample
|
|
74
79
|
):
|
|
75
80
|
super().__init__()
|
|
76
81
|
|
|
@@ -100,12 +105,18 @@ class SDFT(Module):
|
|
|
100
105
|
assert get_variables_from_template(student_prompt_template) == {'question'}
|
|
101
106
|
self.student_prompt_template = Template(student_prompt_template)
|
|
102
107
|
|
|
108
|
+
# end of string
|
|
109
|
+
|
|
110
|
+
self.eos_id = eos_id
|
|
111
|
+
|
|
103
112
|
def forward(
|
|
104
113
|
self,
|
|
105
114
|
questions: list[str],
|
|
106
115
|
answers: list[str],
|
|
107
116
|
student_logit_sample_kwargs: dict = dict()
|
|
108
117
|
):
|
|
118
|
+
maybe_eos_id = self.eos_id
|
|
119
|
+
|
|
109
120
|
encode = self.tokenizer_encode
|
|
110
121
|
assert len(questions) == len(answers)
|
|
111
122
|
|
|
@@ -128,8 +139,8 @@ class SDFT(Module):
|
|
|
128
139
|
|
|
129
140
|
# accumulate
|
|
130
141
|
|
|
131
|
-
|
|
132
|
-
token_kl_div_losses =
|
|
142
|
+
student_responses = None
|
|
143
|
+
token_kl_div_losses = None
|
|
133
144
|
|
|
134
145
|
for _ in range(self.student_max_response_len):
|
|
135
146
|
|
|
@@ -156,23 +167,34 @@ class SDFT(Module):
|
|
|
156
167
|
|
|
157
168
|
).sum(dim = -1)
|
|
158
169
|
|
|
159
|
-
token_kl_div_losses
|
|
170
|
+
token_kl_div_losses = safe_cat((token_kl_div_losses, token_kl_div), dim = 1)
|
|
160
171
|
|
|
161
172
|
# sample
|
|
162
173
|
|
|
163
174
|
sampled_action = self.discrete_readout.sample(student_token_logit, **student_logit_sample_kwargs)
|
|
164
|
-
|
|
175
|
+
|
|
176
|
+
student_responses = safe_cat((student_responses, sampled_action), dim = 1)
|
|
177
|
+
|
|
178
|
+
# break if all eos
|
|
179
|
+
|
|
180
|
+
if exists(maybe_eos_id) and (student_responses == maybe_eos_id).any(dim = -1).all():
|
|
181
|
+
break
|
|
165
182
|
|
|
166
183
|
# set student and teacher tokens to the next sampled token
|
|
167
184
|
|
|
168
185
|
student_prompt_ids = sampled_action
|
|
169
186
|
teacher_prompt_ids = sampled_action
|
|
170
187
|
|
|
171
|
-
#
|
|
188
|
+
# handle eos
|
|
189
|
+
|
|
190
|
+
mask = None
|
|
191
|
+
|
|
192
|
+
if exists(maybe_eos_id):
|
|
193
|
+
mask = ((student_responses == maybe_eos_id).cumsum(dim = -1) < 0)
|
|
194
|
+
mask = F.pad(mask, (1, -1), value = True)
|
|
172
195
|
|
|
173
|
-
|
|
174
|
-
token_kl_div_losses = cat(token_kl_div_losses, dim = 1)
|
|
196
|
+
# maybe masked mean for losses
|
|
175
197
|
|
|
176
|
-
loss = token_kl_div_losses
|
|
198
|
+
loss = masked_mean(token_kl_div_losses, mask)
|
|
177
199
|
|
|
178
|
-
return SDFTOutput(loss,
|
|
200
|
+
return SDFTOutput(loss, student_responses)
|
|
@@ -1,7 +1,13 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
param = pytest.mark.parametrize
|
|
3
|
+
|
|
1
4
|
import torch
|
|
2
5
|
from sdft_pytorch.sdft_pytorch import SDFT
|
|
3
6
|
|
|
4
|
-
|
|
7
|
+
@param('eos_id', (None, 1))
|
|
8
|
+
def test_sdft(
|
|
9
|
+
eos_id
|
|
10
|
+
):
|
|
5
11
|
from torch import tensor
|
|
6
12
|
from x_transformers import TransformerWrapper, Decoder
|
|
7
13
|
|
|
@@ -24,6 +30,7 @@ def test_sdft():
|
|
|
24
30
|
sdft_wrapper = SDFT(
|
|
25
31
|
model,
|
|
26
32
|
student_max_response_len = 128,
|
|
33
|
+
eos_id = eos_id,
|
|
27
34
|
tokenizer_encode = tokenizer_encode,
|
|
28
35
|
)
|
|
29
36
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|