sdft-pytorch 0.0.1 (sdft_pytorch-0.0.1.tar.gz)
- sdft_pytorch-0.0.1/LICENSE +21 -0
- sdft_pytorch-0.0.1/PKG-INFO +39 -0
- sdft_pytorch-0.0.1/README.md +19 -0
- sdft_pytorch-0.0.1/pyproject.toml +36 -0
- sdft_pytorch-0.0.1/sdft_pytorch/__init__.py +3 -0
- sdft_pytorch-0.0.1/sdft_pytorch/sdft_pytorch.py +178 -0
- sdft_pytorch-0.0.1/sdft_pytorch.egg-info/PKG-INFO +39 -0
- sdft_pytorch-0.0.1/sdft_pytorch.egg-info/SOURCES.txt +11 -0
- sdft_pytorch-0.0.1/sdft_pytorch.egg-info/dependency_links.txt +1 -0
- sdft_pytorch-0.0.1/sdft_pytorch.egg-info/requires.txt +10 -0
- sdft_pytorch-0.0.1/sdft_pytorch.egg-info/top_level.txt +1 -0
- sdft_pytorch-0.0.1/setup.cfg +4 -0
- sdft_pytorch-0.0.1/tests/test_sdft.py +36 -0
sdft_pytorch-0.0.1/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Phil Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
sdft_pytorch-0.0.1/PKG-INFO
@@ -0,0 +1,39 @@
+Metadata-Version: 2.4
+Name: sdft-pytorch
+Version: 0.0.1
+Summary: SDFT - Pytorch
+Author-email: Phil Wang <lucidrains@gmail.com>
+License: MIT
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: discrete-continuous-embed-readout
+Requires-Dist: einops>=0.8.2
+Requires-Dist: ema-pytorch
+Requires-Dist: Jinja2
+Requires-Dist: torch>=2.5
+Requires-Dist: torch-einops-utils>=0.0.21
+Provides-Extra: test
+Requires-Dist: pytest; extra == "test"
+Requires-Dist: x-transformers; extra == "test"
+Dynamic: license-file
+
+<img src="./sdft.png" width="450px"></img>
+
+## SDFT - Pytorch (wip)
+
+Explorations into the proposed SDFT, [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897), from Shenfeld et al. of MIT
+
+## Citations
+
+```bibtex
+@misc{shenfeld2026selfdistillationenablescontinuallearning,
+    title = {Self-Distillation Enables Continual Learning},
+    author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+    year = {2026},
+    eprint = {2601.19897},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2601.19897},
+}
+```
sdft_pytorch-0.0.1/README.md
@@ -0,0 +1,19 @@
+<img src="./sdft.png" width="450px"></img>
+
+## SDFT - Pytorch (wip)
+
+Explorations into the proposed SDFT, [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897), from Shenfeld et al. of MIT
+
+## Citations
+
+```bibtex
+@misc{shenfeld2026selfdistillationenablescontinuallearning,
+    title = {Self-Distillation Enables Continual Learning},
+    author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+    year = {2026},
+    eprint = {2601.19897},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2601.19897},
+}
+```
sdft_pytorch-0.0.1/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "sdft-pytorch"
+version = "0.0.1"
+description = "SDFT - Pytorch"
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "MIT" }
+authors = [
+    { name = "Phil Wang", email = "lucidrains@gmail.com" }
+]
+dependencies = [
+    "discrete-continuous-embed-readout",
+    "einops>=0.8.2",
+    "ema-pytorch",
+    'Jinja2',
+    "torch>=2.5",
+    "torch-einops-utils>=0.0.21"
+]
+
+[project.optional-dependencies]
+test = [
+    "pytest",
+    "x-transformers"
+]
+
+[tool.pytest.ini_options]
+pythonpath = [
+    "."
+]
+
+[tool.setuptools]
+packages = ["sdft_pytorch"]
sdft_pytorch-0.0.1/sdft_pytorch/sdft_pytorch.py
@@ -0,0 +1,178 @@
+from __future__ import annotations
+from typing import Callable
+from collections import namedtuple
+
+from jinja2 import Template, Environment, meta
+
+import torch
+from torch.nn import Module
+import torch.nn.functional as F
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
+
+from einops import rearrange
+
+from torch_einops_utils import pad_sequence
+
+from ema_pytorch import EMA
+
+from x_transformers import TransformerWrapper
+
+from discrete_continuous_embed_readout import Readout
+
+# default query / demonstration template for in-context learned distillation targets from teacher for student
+
+DEFAULT_STUDENT_PROMPT_TEMPLATE = """
+[Instruction]
+You are a helpful assistant
+
+[Query]
+{{ question }}
+
+[Response]
+"""
+
+DEFAULT_TEACHER_PROMPT_TEMPLATE = """
+[Task Instructions] You are a helpful assistant. Please answer the question based on the provided logic.
+
+[Expert Demonstration] Question: {{ question }} Expert Reasoning and Answer: {{ answer }}
+
+[Current Task] Question: {{ question }} Answer:
+"""
+
+def get_variables_from_template(template):
+
+    env = Environment()
+
+    parsed_template = env.parse(template)
+
+    return set(meta.find_undeclared_variables(parsed_template))
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def maybe_cast_tensor(t):
+    return t if is_tensor(t) else tensor(t)
+
+# classes
+
+SDFTOutput = namedtuple('SDFTOutput', ('loss', 'response'))
+
+class SDFT(Module):
+    def __init__(
+        self,
+        model: TransformerWrapper,
+        tokenizer_encode: Callable[[list[str]], list[Tensor]],
+        student_max_response_len,
+        student_prompt_template = DEFAULT_STUDENT_PROMPT_TEMPLATE,
+        teacher_update_rate = 0.01,
+        teacher_prompt_template = DEFAULT_TEACHER_PROMPT_TEMPLATE,
+    ):
+        super().__init__()
+
+        self.student = model
+
+        self.teacher = EMA(
+            model,
+            beta = 1. - teacher_update_rate,
+            include_online_model = False
+        )
+
+        # sampling
+
+        self.student_max_response_len = student_max_response_len
+
+        self.discrete_readout = Readout(dim = 0, num_discrete = 1)
+
+        # collection of prompts to list[Int['seq']]
+
+        self.tokenizer_encode = tokenizer_encode
+
+        # store templates
+
+        assert get_variables_from_template(teacher_prompt_template) == {'question', 'answer'}, 'your template must contain only variables `question` and `answer`, embedded like so - {{ question }} ... {{ answer }}'
+        self.teacher_prompt_template = Template(teacher_prompt_template)
+
+        assert get_variables_from_template(student_prompt_template) == {'question'}
+        self.student_prompt_template = Template(student_prompt_template)
+
+    def forward(
+        self,
+        questions: list[str],
+        answers: list[str],
+        student_logit_sample_kwargs: dict = dict()
+    ):
+        encode = self.tokenizer_encode
+        assert len(questions) == len(answers)
+
+        student_vars = [{'question': question} for question in questions]
+        teacher_vars = [{'question': question, 'answer': answer} for question, answer in zip(questions, answers)]
+
+        # ready the prompts for student and teacher
+
+        student_prompts_str = [self.student_prompt_template.render(questions) for questions in student_vars]
+        teacher_prompts_str = [self.teacher_prompt_template.render(question_answers) for question_answers in teacher_vars]
+
+        student_prompt_ids = [maybe_cast_tensor(encode(prompt)) for prompt in student_prompts_str]
+        teacher_prompt_ids = [maybe_cast_tensor(encode(prompt)) for prompt in teacher_prompts_str]
+
+        student_prompt_ids, student_seq_start_pos = pad_sequence(student_prompt_ids, return_lens = True, left = True, pad_lens = True)
+        teacher_prompt_ids, teacher_seq_start_pos = pad_sequence(teacher_prompt_ids, return_lens = True, left = True, pad_lens = True)
+
+        student_cache = None
+        teacher_cache = None
+
+        # accumulate
+
+        student_response = []
+        token_kl_div_losses = []
+
+        for _ in range(self.student_max_response_len):
+
+            # forward for logit of student and teacher
+
+            student_logits, student_cache = self.student(student_prompt_ids, cache = student_cache, seq_start_pos = student_seq_start_pos, return_intermediates = True)
+
+            with torch.no_grad():
+                self.teacher.eval()
+                teacher_logits, teacher_cache = self.teacher(teacher_prompt_ids, cache = teacher_cache, seq_start_pos = teacher_seq_start_pos, return_intermediates = True)
+
+            student_token_logit = student_logits[:, -1:]
+            teacher_token_logit = teacher_logits[:, -1:]
+
+            student_token_log_probs = student_token_logit.log_softmax(dim = -1)
+            teacher_token_probs = teacher_token_logit.softmax(dim = -1)
+
+            # privileged self distillation via ICL
+
+            token_kl_div = F.kl_div(
+                student_token_log_probs,
+                teacher_token_probs,
+                reduction = 'none'
+
+            ).sum(dim = -1)
+
+            token_kl_div_losses.append(token_kl_div)
+
+            # sample
+
+            sampled_action = self.discrete_readout.sample(student_token_logit, **student_logit_sample_kwargs)
+            student_response.append(sampled_action)
+
+            # set student and teacher tokens to the next sampled token
+
+            student_prompt_ids = sampled_action
+            teacher_prompt_ids = sampled_action
+
+        # stack and return
+
+        student_response = cat(student_response, dim = 1)
+        token_kl_div_losses = cat(token_kl_div_losses, dim = 1)
+
+        loss = token_kl_div_losses.mean()
+
+        return SDFTOutput(loss, student_response)
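The two Jinja2 prompt templates above carry the core idea: the teacher is the same network (an EMA copy of the student), but it is conditioned on the expert demonstration in-context, while the student sees only the bare query; the per-token KL divergence between their next-token distributions is the distillation loss. Below is a minimal sketch of what the rendered prompts look like, using a made-up question/answer pair that is not part of the package.

```python
# sketch: render the default templates defined in sdft_pytorch/sdft_pytorch.py above
# the question / answer strings here are invented for illustration
from jinja2 import Template

from sdft_pytorch.sdft_pytorch import (
    DEFAULT_STUDENT_PROMPT_TEMPLATE,
    DEFAULT_TEACHER_PROMPT_TEMPLATE,
)

question, answer = '12+48', '60'

# the student prompt contains only the query
student_prompt = Template(DEFAULT_STUDENT_PROMPT_TEMPLATE).render(question = question)

# the teacher prompt additionally embeds the expert demonstration in-context (the privileged input)
teacher_prompt = Template(DEFAULT_TEACHER_PROMPT_TEMPLATE).render(question = question, answer = answer)

print(student_prompt)  # ... [Query] 12+48 ... [Response]
print(teacher_prompt)  # ... [Expert Demonstration] Question: 12+48 Expert Reasoning and Answer: 60 ...
```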
sdft_pytorch-0.0.1/sdft_pytorch.egg-info/PKG-INFO
@@ -0,0 +1,39 @@
+Metadata-Version: 2.4
+Name: sdft-pytorch
+Version: 0.0.1
+Summary: SDFT - Pytorch
+Author-email: Phil Wang <lucidrains@gmail.com>
+License: MIT
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: discrete-continuous-embed-readout
+Requires-Dist: einops>=0.8.2
+Requires-Dist: ema-pytorch
+Requires-Dist: Jinja2
+Requires-Dist: torch>=2.5
+Requires-Dist: torch-einops-utils>=0.0.21
+Provides-Extra: test
+Requires-Dist: pytest; extra == "test"
+Requires-Dist: x-transformers; extra == "test"
+Dynamic: license-file
+
+<img src="./sdft.png" width="450px"></img>
+
+## SDFT - Pytorch (wip)
+
+Explorations into the proposed SDFT, [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897), from Shenfeld et al. of MIT
+
+## Citations
+
+```bibtex
+@misc{shenfeld2026selfdistillationenablescontinuallearning,
+    title = {Self-Distillation Enables Continual Learning},
+    author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+    year = {2026},
+    eprint = {2601.19897},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2601.19897},
+}
+```
sdft_pytorch-0.0.1/sdft_pytorch.egg-info/SOURCES.txt
@@ -0,0 +1,11 @@
+LICENSE
+README.md
+pyproject.toml
+sdft_pytorch/__init__.py
+sdft_pytorch/sdft_pytorch.py
+sdft_pytorch.egg-info/PKG-INFO
+sdft_pytorch.egg-info/SOURCES.txt
+sdft_pytorch.egg-info/dependency_links.txt
+sdft_pytorch.egg-info/requires.txt
+sdft_pytorch.egg-info/top_level.txt
+tests/test_sdft.py
sdft_pytorch-0.0.1/sdft_pytorch.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
sdft_pytorch-0.0.1/sdft_pytorch.egg-info/top_level.txt
@@ -0,0 +1 @@
+sdft_pytorch
sdft_pytorch-0.0.1/tests/test_sdft.py
@@ -0,0 +1,36 @@
+import torch
+from sdft_pytorch.sdft_pytorch import SDFT
+
+def test_sdft():
+    from torch import tensor
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 512,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2
+        )
+    )
+
+    def tokenizer_encode(prompts: list[str]):
+        return [
+            tensor([ord(c) for c in prompt])
+            for prompt in prompts
+        ]
+
+
+    sdft_wrapper = SDFT(
+        model,
+        student_max_response_len = 128,
+        tokenizer_encode = tokenizer_encode,
+    )
+
+    loss, response = sdft_wrapper(
+        questions = ['12+48', '2*3'],
+        answers = ['60', '6']
+    )
+
+    loss.backward()
+    assert response.shape == (2, 128)
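The test above exercises a single forward/backward pass. Below is a hedged sketch of how a full training step might be arranged around the wrapper: `SDFT.forward` does not itself step the EMA teacher, so the sketch calls `update()` on it after the optimizer step. The optimizer, learning rate, and toy batch are assumptions for illustration, not values from the package.

```python
# sketch of one SDFT training step; hyperparameters and the toy data are assumptions
import torch
from torch import tensor
from x_transformers import TransformerWrapper, Decoder
from sdft_pytorch.sdft_pytorch import SDFT

# model and character-level tokenizer, mirroring the test above
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 2)
)

def tokenizer_encode(prompts: list[str]):
    return [tensor([ord(c) for c in prompt]) for prompt in prompts]

sdft_wrapper = SDFT(
    model,
    student_max_response_len = 128,
    tokenizer_encode = tokenizer_encode,
    teacher_update_rate = 0.01    # teacher EMA decay is 1 - teacher_update_rate
)

optimizer = torch.optim.Adam(sdft_wrapper.student.parameters(), lr = 3e-4)

data = [(['12+48', '2*3'], ['60', '6'])]    # any stream of (questions, answers) batches

for questions, answers in data:
    loss, _ = sdft_wrapper(questions = questions, answers = answers)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # SDFT.forward does not step the EMA teacher, so move it toward the updated student here
    sdft_wrapper.teacher.update()
```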
|