tiny-recursive-model 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tiny_recursive_model/__init__.py +7 -0
- tiny_recursive_model/trainer.py +157 -0
- tiny_recursive_model/trm.py +74 -85
- {tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/METADATA +48 -1
- tiny_recursive_model-0.0.3.dist-info/RECORD +8 -0
- tiny_recursive_model-0.0.1.dist-info/RECORD +0 -7
- {tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/WHEEL +0 -0
- {tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/licenses/LICENSE +0 -0
tiny_recursive_model/trainer.py
ADDED
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+import torch
+from torch.nn import Module
+from torch.optim import AdamW
+from torch.utils.data import Dataset, DataLoader
+
+from einops import pack, unpack
+
+from accelerate import Accelerator
+
+# ema - apparently greatly helped with results
+
+from ema_pytorch import EMA
+
+from tiny_recursive_model.trm import TinyRecursiveModel
+
+# helpers
+
+def range_from_one(n):
+    return range(1, n + 1)
+
+def is_empty(t):
+    return t.numel() == 0
+
+# trainer
+
+def newtonschulz5(
+    t,
+    steps = 5,
+    eps = 1e-7,
+    coefs = (3.4445, -4.7750, 2.0315)
+):
+    if t.ndim <= 3:
+        return t
+
+    shape = t.shape
+    should_transpose = shape[-2] > shape[-1]
+
+    if should_transpose:
+        t = t.transpose(-1, -2)
+
+    t, packed_shape = pack([t], '* i j')
+    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)
+
+    a, b, c = coefs
+
+    for _ in range(steps):
+        A = t @ t.transpose(-1, -2)
+        B = b * A + c * A @ A
+        t = a * t + B @ t
+
+    t, = unpack(t, packed_shape, '* i j')
+
+    if should_transpose:
+        t = t.transpose(-1, -2)
+
+    return t
+
+class Trainer(Module):
+    def __init__(
+        self,
+        model: TinyRecursiveModel | Module,
+        dataset: Dataset,
+        optim_klass = AdamW,
+        learning_rate = 1e-4,
+        weight_decay = 1.,
+        batch_size = 16,
+        epochs = 2,
+        halt_prob_thres = 0.5,
+        max_recurrent_steps = 12,
+        ema_decay_rate = 0.999,
+        switch_ema_every = 10000, # switch ema https://arxiv.org/abs/2402.09240
+        accelerate_kwargs: dict = dict(),
+        cpu = False
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(**accelerate_kwargs, cpu = cpu)
+
+        self.batch_size = batch_size
+        self.epochs = epochs
+
+        self.dataset = dataset
+        self.dataloader = dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)
+
+        self.optim = optim_klass(
+            model.parameters(),
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.model = model
+
+        # ema model
+
+        self.ema_model = None
+
+        if self.accelerator.is_main_process:
+            self.ema_model = EMA(
+                model,
+                beta = ema_decay_rate,
+                update_model_with_ema_every = switch_ema_every,
+                forward_method_names = ('predict',)
+            )
+
+        # recurrent and act related variables
+
+        self.halt_prob_thres = halt_prob_thres
+
+        self.max_recurrent_steps = max_recurrent_steps
+
+        # prepare maybe distributed
+
+        self.model, self.optim, self.dataloader = self.accelerator.prepare(self.model, self.optim, self.dataloader)
+
+    def forward(self):
+
+        for epoch in range_from_one(self.epochs):
+
+            for dataset_input, dataset_output in self.dataloader:
+
+                outputs, latents = self.model.get_initial()
+
+                for recurrent_step in range_from_one(self.max_recurrent_steps):
+
+                    loss, (main_loss, halt_loss), outputs, latents, pred, halt = self.model(dataset_input, outputs, latents, labels = dataset_output)
+
+                    self.accelerator.print(f'[{epoch} ({recurrent_step} / {self.max_recurrent_steps})] loss: {main_loss.mean().item():.3f} | halt loss: {halt_loss.mean().item():.3f}')
+
+                    self.accelerator.backward(loss)
+
+                    self.optim.step()
+                    self.optim.zero_grad()
+
+                    if self.accelerator.is_main_process:
+                        self.ema_model.update()
+
+                    # handle halting
+
+                    halt_mask = halt >= self.halt_prob_thres
+
+                    if not halt_mask.any():
+                        continue
+
+                    outputs = outputs[~halt_mask]
+                    latents = latents[~halt_mask]
+                    dataset_input = dataset_input[~halt_mask]
+                    dataset_output = dataset_output[~halt_mask]
+
+                    if is_empty(outputs):
+                        break
+
+        self.accelerator.print('complete')
+
+        if self.accelerator.is_main_process:
+            self.ema_model.copy_params_from_ema_to_model()
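The new trainer module defines `newtonschulz5` but does not call it anywhere in the diff above. The sketch below is not part of the package (the import path is assumed from the file layout shown); it illustrates what the quintic Newton–Schulz iteration does to the trailing two dimensions of a tensor: singular values are pushed toward 1, i.e. the matrices are approximately orthogonalized. Note the guard at the top of the function, which returns tensors of 3 or fewer dimensions unchanged.

```python
import torch
# assumed import path, based on the trainer.py layout shown in this diff
from tiny_recursive_model.trainer import newtonschulz5

x = torch.randn(2, 3, 8, 16)     # more than 3 dims, so the `t.ndim <= 3` early return is not taken
y = newtonschulz5(x, steps = 5)

# each trailing (8, 16) matrix is approximately orthogonalized:
# singular values start spread out and end up clustered roughly around 1
print(torch.linalg.svdvals(x[0, 0]))
print(torch.linalg.svdvals(y[0, 0]))
```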
tiny_recursive_model/trm.py
CHANGED
@@ -2,13 +2,11 @@ from __future__ import annotations
 from contextlib import nullcontext
 
 import torch
-from torch import nn
+from torch import nn, cat, arange, tensor
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
-from torch.optim import AdamW
-from torch.utils.data import Dataset, DataLoader
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Reduce, Rearrange
 
 # network related
@@ -16,10 +14,6 @@ from einops.layers.torch import Reduce, Rearrange
 from x_transformers import Encoder
 from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D
 
-# ema - apparently greatly helped with results
-
-from ema_pytorch import EMA
-
 # helpers
 
 def exists(v):
@@ -28,9 +22,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def range_from_one(n):
-    return range(1, n + 1)
-
 def is_empty(t):
     return t.numel() == 0
 
@@ -72,6 +63,10 @@ class TinyRecursiveModel(Module):
 
         self.halt_loss_weight = halt_loss_weight
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     def refine_latent_then_output_once(
         self,
         inputs, # (b n d)
@@ -115,118 +110,112 @@ class TinyRecursiveModel(Module):
 
         return outputs, latents
 
-
+    @torch.no_grad()
+    def predict(
         self,
         seq,
-
-
-        labels = None
+        halt_prob_thres = 0.5,
+        num_deep_refinement_steps = 12
     ):
+        batch = seq.shape[0]
+
        inputs = self.input_embed(seq)
 
-        outputs, latents = self.
+        outputs, latents = self.get_initial()
 
-
+        # active batch indices, the step it exited at, and the final output predictions
 
-
+        active_batch_indices = arange(batch, device = self.device, dtype = torch.float32)
 
-
+        preds = []
+        exited_step_indices = []
+        exited_batch_indices = []
 
-
+        for i in range(num_deep_refinement_steps):
+            step = i + 1
+            is_last = step == num_deep_refinement_steps
 
-
-        return return_package
+            outputs, latents = self.deep_refinement(inputs, outputs, latents)
 
-
+            halt_prob = self.to_halt_pred(outputs)
 
-
+            should_halt = (halt_prob >= halt_prob_thres) | is_last
 
-
+            if not should_halt.any():
+                continue
 
-
+            # append to exited predictions
 
-
+            pred = self.to_pred(outputs[should_halt])
+            preds.append(pred)
 
-
-        losses = (loss, halt_loss)
+            # append the step at which early halted
 
-
+            exited_step_indices.extend([step] * should_halt.sum().item())
 
-#
+            # append indices for sorting back
 
-
-    def __init__(
-        self,
-        model: TinyRecursiveModel | Module,
-        dataset: Dataset,
-        optim_klass = AdamW,
-        learning_rate = 1e-4,
-        weight_decay = 1.,
-        batch_size = 16,
-        epochs = 2,
-        halt_prob_thres = 0.5,
-        max_recurrent_steps = 12,
-        ema_decay_rate = 0.999,
-        ema_update_model_with_ema_every = 10000
-    ):
-        super().__init__()
+            exited_batch_indices.append(active_batch_indices[should_halt])
 
-
-
+            if is_last:
+                continue
 
-
-        self.dataloader = dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)
+            # ready for next round
 
-
-
-
-
-        )
+            inputs = inputs[~should_halt]
+            outputs = outputs[~should_halt]
+            latents = latents[~should_halt]
+            active_batch_indices = active_batch_indices[~should_halt]
 
-
+            if is_empty(outputs):
+                break
 
-
-
-            beta = ema_decay_rate,
-            update_model_with_ema_every = ema_update_model_with_ema_every
-        )
+        preds = cat(preds).argmax(dim = -1)
+        exited_step_indices = tensor(exited_step_indices)
 
-
+        exited_batch_indices = cat(exited_batch_indices)
+        sort_indices = exited_batch_indices.argsort(dim = -1)
 
-
+        return preds[sort_indices], exited_step_indices[sort_indices]
 
-    def forward(
+    def forward(
+        self,
+        seq,
+        outputs,
+        latents,
+        labels = None
+    ):
+        inputs = self.input_embed(seq)
 
-
+        outputs, latents = self.deep_refinement(inputs, outputs, latents)
 
-
+        pred = self.to_pred(outputs)
 
-
+        halt_prob = self.to_halt_pred(outputs)
 
-
+        outputs, latents = outputs.detach(), latents.detach()
 
-
+        return_package = (outputs, latents, pred, halt_prob)
 
-
+        if not exists(labels):
+            return return_package
 
-
+        # calculate loss if labels passed in
 
-
-
+        loss = F.cross_entropy(rearrange(pred, 'b n l -> b l n'), labels, reduction = 'none')
+        loss = reduce(loss, 'b ... -> b', 'mean')
 
-
+        is_all_correct = (pred.argmax(dim = -1) == labels).all(dim = -1)
 
-
+        halt_loss = F.binary_cross_entropy(halt_prob, is_all_correct.float(), reduction = 'none')
 
-
+        # total loss and loss breakdown
 
-
-
+        total_loss = (
+            loss +
+            halt_loss * self.halt_loss_weight
+        )
 
-
-        latents = latents[~halt_mask]
-        dataset_input = dataset_input[~halt_mask]
-        dataset_output = dataset_output[~halt_mask]
+        losses = (loss, halt_loss)
 
-
-            break
+        return (total_loss.mean(), losses, *return_package)
{tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tiny-recursive-model
-Version: 0.0.1
+Version: 0.0.3
 Summary: Tiny Recursive Model
 Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
 Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
@@ -55,6 +55,53 @@ Official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursi
 
 <img width="300" alt="trm-fig3" src="https://github.com/user-attachments/assets/bfe3dd2a-e859-492a-84d5-faf37339f534" />
 
+## Install
+
+```bash
+$ pip install tiny-recursive-model
+```
+
+## Usage
+
+```python
+import torch
+from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer
+
+trm = TinyRecursiveModel(
+    dim = 16,
+    num_tokens = 256,
+    network = MLPMixer1D(
+        dim = 16,
+        depth = 2,
+        seq_len = 256
+    ),
+)
+
+from torch.utils.data import Dataset
+class MockDataset(Dataset):
+    def __len__(self):
+        return 16
+
+    def __getitem__(self, idx):
+        inp = torch.randint(0, 256, (256,))
+        out = torch.randint(0, 256, (256,))
+        return inp, out
+
+trainer = Trainer(
+    trm,
+    MockDataset(),
+    epochs = 1,
+    batch_size = 16,
+    cpu = True
+)
+
+trainer()
+
+pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)), halt_prob_thres = 0.1)
+
+torch.save(trm.state_dict(), 'saved-trm.pt')
+```
+
 ## Citations
 
 ```bibtex
tiny_recursive_model-0.0.3.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
+tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
+tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
+tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
+tiny_recursive_model-0.0.3.dist-info/METADATA,sha256=0enBPVOxRoReOf0hms_ZoAI4HHdMWUrrW4Ps0MTuQ9g,3943
+tiny_recursive_model-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+tiny_recursive_model-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+tiny_recursive_model-0.0.3.dist-info/RECORD,,
tiny_recursive_model-0.0.1.dist-info/RECORD
REMOVED
@@ -1,7 +0,0 @@
-tiny_recursive_model/__init__.py,sha256=UufV6--ilPn4quRWyhvaFRMKRfHvfLsAmF9RU-L31rM,77
-tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
-tiny_recursive_model/trm.py,sha256=YwzTod4CeeXlbAiM-TBB7rEEHWsxnPxavaGiVCTPMEM,6350
-tiny_recursive_model-0.0.1.dist-info/METADATA,sha256=G-cM7okuLAiOxhofXoRh2Ih-bwYifcA3AAhmYmKo-v4,3107
-tiny_recursive_model-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-tiny_recursive_model-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-tiny_recursive_model-0.0.1.dist-info/RECORD,,
{tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/WHEEL
RENAMED
File without changes
{tiny_recursive_model-0.0.1.dist-info → tiny_recursive_model-0.0.3.dist-info}/licenses/LICENSE
RENAMED
File without changes