sdft-pytorch 0.0.1 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,39 @@
+ Metadata-Version: 2.4
+ Name: sdft-pytorch
+ Version: 0.0.1
+ Summary: SDFT - Pytorch
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: discrete-continuous-embed-readout
+ Requires-Dist: einops>=0.8.2
+ Requires-Dist: ema-pytorch
+ Requires-Dist: Jinja2
+ Requires-Dist: torch>=2.5
+ Requires-Dist: torch-einops-utils>=0.0.21
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == "test"
+ Requires-Dist: x-transformers; extra == "test"
+ Dynamic: license-file
+
+ <img src="./sdft.png" width="450px"></img>
+
+ ## SDFT - Pytorch (wip)
+
+ Explorations into SDFT, proposed in [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897) by Shenfeld et al. of MIT
+
+ ## Citations
+
+ ```bibtex
+ @misc{shenfeld2026selfdistillationenablescontinuallearning,
+     title = {Self-Distillation Enables Continual Learning},
+     author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+     year = {2026},
+     eprint = {2601.19897},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2601.19897},
+ }
+ ```
@@ -0,0 +1,19 @@
+ <img src="./sdft.png" width="450px"></img>
+
+ ## SDFT - Pytorch (wip)
+
+ Explorations into SDFT, proposed in [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897) by Shenfeld et al. of MIT
+
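+ ## Usage
+
+ A minimal usage sketch, mirroring `tests/test_sdft.py` - the character-level `tokenizer_encode` and the small `x-transformers` decoder below are only stand-ins for a real tokenizer and pretrained model:
+
+ ```python
+ import torch
+ from torch import tensor
+ from x_transformers import TransformerWrapper, Decoder
+
+ from sdft_pytorch import SDFT
+
+ model = TransformerWrapper(
+     num_tokens = 256,
+     max_seq_len = 512,
+     attn_layers = Decoder(dim = 512, depth = 2)
+ )
+
+ # toy character-level tokenizer, maps a list of prompt strings to a list of token id tensors
+
+ def tokenizer_encode(prompts: list[str]) -> list[torch.Tensor]:
+     return [tensor([ord(c) for c in prompt]) for prompt in prompts]
+
+ sdft = SDFT(
+     model,
+     tokenizer_encode = tokenizer_encode,
+     student_max_response_len = 128
+ )
+
+ # the EMA teacher sees the expert answer in-context, the student is distilled toward it while sampling its own response
+
+ loss, response = sdft(
+     questions = ['12+48', '2*3'],
+     answers = ['60', '6']
+ )
+
+ loss.backward()
+
+ # response - (batch, student_max_response_len) sampled token ids
+ ```
+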
+ ## Citations
+
+ ```bibtex
+ @misc{shenfeld2026selfdistillationenablescontinuallearning,
+     title = {Self-Distillation Enables Continual Learning},
+     author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+     year = {2026},
+     eprint = {2601.19897},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2601.19897},
+ }
+ ```
@@ -0,0 +1,36 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "sdft-pytorch"
+ version = "0.0.1"
+ description = "SDFT - Pytorch"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ license = { text = "MIT" }
+ authors = [
+     { name = "Phil Wang", email = "lucidrains@gmail.com" }
+ ]
+ dependencies = [
+     "discrete-continuous-embed-readout",
+     "einops>=0.8.2",
+     "ema-pytorch",
+     "Jinja2",
+     "torch>=2.5",
+     "torch-einops-utils>=0.0.21"
+ ]
+
+ [project.optional-dependencies]
+ test = [
+     "pytest",
+     "x-transformers"
+ ]
+
+ [tool.pytest.ini_options]
+ pythonpath = [
+     "."
+ ]
+
+ [tool.setuptools]
+ packages = ["sdft_pytorch"]
@@ -0,0 +1,3 @@
+ from sdft_pytorch.sdft_pytorch import (
+     SDFT
+ )
@@ -0,0 +1,178 @@
+ from __future__ import annotations
+ from typing import Callable
+ from collections import namedtuple
+
+ from jinja2 import Template, Environment, meta
+
+ import torch
+ from torch.nn import Module
+ import torch.nn.functional as F
+ from torch import nn, cat, stack, is_tensor, tensor, Tensor
+
+ from einops import rearrange
+
+ from torch_einops_utils import pad_sequence
+
+ from ema_pytorch import EMA
+
+ from x_transformers import TransformerWrapper
+
+ from discrete_continuous_embed_readout import Readout
+
+ # default query / demonstration template for in-context learned distillation targets from teacher for student
+
+ DEFAULT_STUDENT_PROMPT_TEMPLATE = """
+ [Instruction]
+ You are a helpful assistant
+
+ [Query]
+ {{ question }}
+
+ [Response]
+ """
+
+ DEFAULT_TEACHER_PROMPT_TEMPLATE = """
+ [Task Instructions] You are a helpful assistant. Please answer the question based on the provided logic.
+
+ [Expert Demonstration] Question: {{ question }} Expert Reasoning and Answer: {{ answer }}
+
+ [Current Task] Question: {{ question }} Answer:
+ """
+
+ def get_variables_from_template(template):
+
+     env = Environment()
+
+     parsed_template = env.parse(template)
+
+     return set(meta.find_undeclared_variables(parsed_template))
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def maybe_cast_tensor(t):
+     return t if is_tensor(t) else tensor(t)
+
+ # classes
+
+ SDFTOutput = namedtuple('SDFTOutput', ('loss', 'response'))
+
+ class SDFT(Module):
+     def __init__(
+         self,
+         model: TransformerWrapper,
+         tokenizer_encode: Callable[[list[str]], list[Tensor]],
+         student_max_response_len,
+         student_prompt_template = DEFAULT_STUDENT_PROMPT_TEMPLATE,
+         teacher_update_rate = 0.01,
+         teacher_prompt_template = DEFAULT_TEACHER_PROMPT_TEMPLATE,
+     ):
+         super().__init__()
+
+         self.student = model
+
+         self.teacher = EMA(
+             model,
+             beta = 1. - teacher_update_rate,
+             include_online_model = False
+         )
+
+         # sampling
+
+         self.student_max_response_len = student_max_response_len
+
+         self.discrete_readout = Readout(dim = 0, num_discrete = 1)
+
+         # collection of prompts to list[Int['seq']]
+
+         self.tokenizer_encode = tokenizer_encode
+
+         # store templates
+
+         assert get_variables_from_template(teacher_prompt_template) == {'question', 'answer'}, 'your template must contain only variables `question` and `answer`, embedded like so - {{ question }} ... {{ answer }}'
+         self.teacher_prompt_template = Template(teacher_prompt_template)
+
+         assert get_variables_from_template(student_prompt_template) == {'question'}
+         self.student_prompt_template = Template(student_prompt_template)
+
+     def forward(
+         self,
+         questions: list[str],
+         answers: list[str],
+         student_logit_sample_kwargs: dict = dict()
+     ):
+         encode = self.tokenizer_encode
+         assert len(questions) == len(answers)
+
+         student_vars = [{'question': question} for question in questions]
+         teacher_vars = [{'question': question, 'answer': answer} for question, answer in zip(questions, answers)]
+
+         # ready the prompts for student and teacher
+
+         student_prompts_str = [self.student_prompt_template.render(template_vars) for template_vars in student_vars]
+         teacher_prompts_str = [self.teacher_prompt_template.render(question_answers) for question_answers in teacher_vars]
+
+         student_prompt_ids = [maybe_cast_tensor(prompt_ids) for prompt_ids in encode(student_prompts_str)]
+         teacher_prompt_ids = [maybe_cast_tensor(prompt_ids) for prompt_ids in encode(teacher_prompts_str)]
+
+         student_prompt_ids, student_seq_start_pos = pad_sequence(student_prompt_ids, return_lens = True, left = True, pad_lens = True)
+         teacher_prompt_ids, teacher_seq_start_pos = pad_sequence(teacher_prompt_ids, return_lens = True, left = True, pad_lens = True)
+
+         student_cache = None
+         teacher_cache = None
+
+         # accumulate
+
+         student_response = []
+         token_kl_div_losses = []
+
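+         # decoding loop - the student samples its own response token by token, while its next
+         # token distribution is pulled (through the KL below) toward that of the EMA teacher,
+         # which is privileged with the expert answer in its prompt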
+         for _ in range(self.student_max_response_len):
+
+             # forward for logit of student and teacher
+
+             student_logits, student_cache = self.student(student_prompt_ids, cache = student_cache, seq_start_pos = student_seq_start_pos, return_intermediates = True)
+
+             with torch.no_grad():
+                 self.teacher.eval()
+                 teacher_logits, teacher_cache = self.teacher(teacher_prompt_ids, cache = teacher_cache, seq_start_pos = teacher_seq_start_pos, return_intermediates = True)
+
+             student_token_logit = student_logits[:, -1:]
+             teacher_token_logit = teacher_logits[:, -1:]
+
+             student_token_log_probs = student_token_logit.log_softmax(dim = -1)
+             teacher_token_probs = teacher_token_logit.softmax(dim = -1)
+
+             # privileged self distillation via ICL
+
+             token_kl_div = F.kl_div(
+                 student_token_log_probs,
+                 teacher_token_probs,
+                 reduction = 'none'
+             ).sum(dim = -1)
+
+             token_kl_div_losses.append(token_kl_div)
+
+             # sample
+
+             sampled_action = self.discrete_readout.sample(student_token_logit, **student_logit_sample_kwargs)
+             student_response.append(sampled_action)
+
+             # set student and teacher tokens to the next sampled token
+
+             student_prompt_ids = sampled_action
+             teacher_prompt_ids = sampled_action
+
+         # stack and return
+
+         student_response = cat(student_response, dim = 1)
+         token_kl_div_losses = cat(token_kl_div_losses, dim = 1)
+
+         loss = token_kl_div_losses.mean()
+
+         return SDFTOutput(loss, student_response)
@@ -0,0 +1,39 @@
+ Metadata-Version: 2.4
+ Name: sdft-pytorch
+ Version: 0.0.1
+ Summary: SDFT - Pytorch
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: discrete-continuous-embed-readout
+ Requires-Dist: einops>=0.8.2
+ Requires-Dist: ema-pytorch
+ Requires-Dist: Jinja2
+ Requires-Dist: torch>=2.5
+ Requires-Dist: torch-einops-utils>=0.0.21
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == "test"
+ Requires-Dist: x-transformers; extra == "test"
+ Dynamic: license-file
+
+ <img src="./sdft.png" width="450px"></img>
+
+ ## SDFT - Pytorch (wip)
+
+ Explorations into SDFT, proposed in [Self-Distillation Enables Continual Learning](https://arxiv.org/abs/2601.19897) by Shenfeld et al. of MIT
+
+ ## Citations
+
+ ```bibtex
+ @misc{shenfeld2026selfdistillationenablescontinuallearning,
+     title = {Self-Distillation Enables Continual Learning},
+     author = {Idan Shenfeld and Mehul Damani and Jonas Hübotter and Pulkit Agrawal},
+     year = {2026},
+     eprint = {2601.19897},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2601.19897},
+ }
+ ```
@@ -0,0 +1,11 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ sdft_pytorch/__init__.py
+ sdft_pytorch/sdft_pytorch.py
+ sdft_pytorch.egg-info/PKG-INFO
+ sdft_pytorch.egg-info/SOURCES.txt
+ sdft_pytorch.egg-info/dependency_links.txt
+ sdft_pytorch.egg-info/requires.txt
+ sdft_pytorch.egg-info/top_level.txt
+ tests/test_sdft.py
@@ -0,0 +1,10 @@
+ discrete-continuous-embed-readout
+ einops>=0.8.2
+ ema-pytorch
+ Jinja2
+ torch>=2.5
+ torch-einops-utils>=0.0.21
+
+ [test]
+ pytest
+ x-transformers
@@ -0,0 +1 @@
+ sdft_pytorch
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,36 @@
+ import torch
+ from sdft_pytorch.sdft_pytorch import SDFT
+
+ def test_sdft():
+     from torch import tensor
+     from x_transformers import TransformerWrapper, Decoder
+
+     model = TransformerWrapper(
+         num_tokens = 256,
+         max_seq_len = 512,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 2
+         )
+     )
+
+     def tokenizer_encode(prompts: list[str]):
+         return [
+             tensor([ord(c) for c in prompt])
+             for prompt in prompts
+         ]
+
+     sdft_wrapper = SDFT(
+         model,
+         student_max_response_len = 128,
+         tokenizer_encode = tokenizer_encode,
+     )
+
+     loss, response = sdft_wrapper(
+         questions = ['12+48', '2*3'],
+         answers = ['60', '6']
+     )
+
+     loss.backward()
+     assert response.shape == (2, 128)