sdft-pytorch 0.0.1.tar.gz → 0.0.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sdft-pytorch
-Version: 0.0.1
+Version: 0.0.2
 Summary: SDFT - Pytorch
 Author-email: Phil Wang <lucidrains@gmail.com>
 License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sdft-pytorch"
-version = "0.0.1"
+version = "0.0.2"
 description = "SDFT - Pytorch"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -11,7 +11,11 @@ from torch import nn, cat, stack, is_tensor, tensor, Tensor
 
 from einops import rearrange
 
-from torch_einops_utils import pad_sequence
+from torch_einops_utils import (
+    pad_sequence,
+    safe_cat,
+    masked_mean
+)
 
 from ema_pytorch import EMA
 
@@ -71,6 +75,7 @@ class SDFT(Module):
         student_prompt_template = DEFAULT_STUDENT_PROMPT_TEMPLATE,
         teacher_update_rate = 0.01,
         teacher_prompt_template = DEFAULT_TEACHER_PROMPT_TEMPLATE,
+        eos_id = None, # if set, will mask out any losses after the first eos token id detected in a given sample
     ):
         super().__init__()
 
@@ -100,12 +105,18 @@ class SDFT(Module):
         assert get_variables_from_template(student_prompt_template) == {'question'}
         self.student_prompt_template = Template(student_prompt_template)
 
+        # end of string
+
+        self.eos_id = eos_id
+
     def forward(
         self,
         questions: list[str],
         answers: list[str],
         student_logit_sample_kwargs: dict = dict()
     ):
+        maybe_eos_id = self.eos_id
+
         encode = self.tokenizer_encode
         assert len(questions) == len(answers)
 
@@ -128,8 +139,8 @@ class SDFT(Module):
 
         # accumulate
 
-        student_response = []
-        token_kl_div_losses = []
+        student_responses = None
+        token_kl_div_losses = None
 
         for _ in range(self.student_max_response_len):
 
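The change above replaces Python-list accumulation (followed by a final cat) with running tensors that start as None and grow via safe_cat from torch_einops_utils. Below is a minimal sketch of that accumulation pattern, assuming safe_cat simply drops None entries before concatenating; the _safe_cat helper and the toy tensors are stand-ins for illustration, not the library function.

import torch
from torch import cat

def _safe_cat(tensors, dim = 1):
    # hypothetical stand-in: ignore None entries, return None if nothing is left
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return cat(tensors, dim = dim)

responses = None

for _ in range(4):
    sampled = torch.randint(0, 10, (2, 1)) # one sampled token per batch element
    responses = _safe_cat((responses, sampled), dim = 1)

print(responses.shape) # torch.Size([2, 4])

Starting from None avoids special-casing the first iteration and lets the running tensor pick up its device and dtype from the first concatenated value.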
@@ -156,23 +167,34 @@
 
             ).sum(dim = -1)
 
-            token_kl_div_losses.append(token_kl_div)
+            token_kl_div_losses = safe_cat((token_kl_div_losses, token_kl_div), dim = 1)
 
             # sample
 
             sampled_action = self.discrete_readout.sample(student_token_logit, **student_logit_sample_kwargs)
-            student_response.append(sampled_action)
+
+            student_responses = safe_cat((student_responses, sampled_action), dim = 1)
+
+            # break if all eos
+
+            if exists(maybe_eos_id) and (student_responses == maybe_eos_id).any(dim = -1).all():
+                break
 
             # set student and teacher tokens to the next sampled token
 
             student_prompt_ids = sampled_action
             teacher_prompt_ids = sampled_action
 
-        # stack and return
+        # handle eos
+
+        mask = None
+
+        if exists(maybe_eos_id):
+            mask = ((student_responses == maybe_eos_id).cumsum(dim = -1) < 0)
+            mask = F.pad(mask, (1, -1), value = True)
 
-        student_response = cat(student_response, dim = 1)
-        token_kl_div_losses = cat(token_kl_div_losses, dim = 1)
+        # maybe masked mean for losses
 
-        loss = token_kl_div_losses.mean()
+        loss = masked_mean(token_kl_div_losses, mask)
 
-        return SDFTOutput(loss, student_response)
+        return SDFTOutput(loss, student_responses)
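Taken together, the forward-pass changes above stop decoding once every sample in the batch has emitted the configured eos_id, and they exclude positions after the first eos from the averaged KL loss. Below is a rough standalone sketch of that masking idea with made-up tensors; the boolean-indexed mean stands in for torch_einops_utils.masked_mean, and the cumsum == 0 test is the usual way to mark positions before the first eos, shifted right by one so the eos token itself still contributes.

import torch
import torch.nn.functional as F

eos_id = 1

# made-up responses: first row emits eos at position 2, second row never does
responses = torch.tensor([
    [5, 7, 1, 3],
    [4, 6, 8, 9],
])

token_losses = torch.rand(2, 4) # per-token losses, same shape as responses

# keep everything up to and including the first eos
mask = (responses == eos_id).cumsum(dim = -1) == 0
mask = F.pad(mask, (1, -1), value = True)

# stand-in for masked_mean: average only the kept positions
loss = token_losses[mask].mean()

print(mask) # tensor([[ True,  True,  True, False], [ True,  True,  True,  True]])
print(loss)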
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sdft-pytorch
-Version: 0.0.1
+Version: 0.0.2
 Summary: SDFT - Pytorch
 Author-email: Phil Wang <lucidrains@gmail.com>
 License: MIT
@@ -1,7 +1,13 @@
+import pytest
+param = pytest.mark.parametrize
+
 import torch
 from sdft_pytorch.sdft_pytorch import SDFT
 
-def test_sdft():
+@param('eos_id', (None, 1))
+def test_sdft(
+    eos_id
+):
     from torch import tensor
     from x_transformers import TransformerWrapper, Decoder
 
@@ -24,6 +30,7 @@ def test_sdft():
     sdft_wrapper = SDFT(
         model,
         student_max_response_len = 128,
+        eos_id = eos_id,
         tokenizer_encode = tokenizer_encode,
     )
 
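For reference, pytest.mark.parametrize (aliased to param in the updated test) runs test_sdft once per listed value, so both the unmasked path (eos_id = None) and the masked path (eos_id = 1) are exercised. A tiny self-contained illustration of that mechanism follows; the test name below is hypothetical.

import pytest

@pytest.mark.parametrize('eos_id', (None, 1))
def test_eos_id_values(eos_id):
    # pytest generates one independent test case per parameter value
    assert eos_id is None or isinstance(eos_id, int)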