tiny-recursive-model 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

tiny_recursive_model/__init__.py

@@ -1,4 +1,7 @@
  from tiny_recursive_model.trm import (
      TinyRecursiveModel,
+ )
+
+ from tiny_recursive_model.trainer import (
      Trainer
  )
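
Since __init__.py now re-exports both classes, downstream code can import them straight from the package root, e.g. (a minimal sketch):

    from tiny_recursive_model import TinyRecursiveModel, Trainer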

tiny_recursive_model/trainer.py (new file)

@@ -0,0 +1,157 @@
+ from __future__ import annotations
+
+ import torch
+ from torch.nn import Module
+ from torch.optim import AdamW
+ from torch.utils.data import Dataset, DataLoader
+
+ from einops import pack, unpack
+
+ from accelerate import Accelerator
+
+ # ema - apparently greatly helped with results
+
+ from ema_pytorch import EMA
+
+ from tiny_recursive_model.trm import TinyRecursiveModel
+
+ # helpers
+
+ def range_from_one(n):
+     return range(1, n + 1)
+
+ def is_empty(t):
+     return t.numel() == 0
+
+ # trainer
+
+ def newtonschulz5(
+     t,
+     steps = 5,
+     eps = 1e-7,
+     coefs = (3.4445, -4.7750, 2.0315)
+ ):
+     if t.ndim <= 3:
+         return t
+
+     shape = t.shape
+     should_transpose = shape[-2] > shape[-1]
+
+     if should_transpose:
+         t = t.transpose(-1, -2)
+
+     t, packed_shape = pack([t], '* i j')
+     t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)
+
+     a, b, c = coefs
+
+     for _ in range(steps):
+         A = t @ t.transpose(-1, -2)
+         B = b * A + c * A @ A
+         t = a * t + B @ t
+
+     t, = unpack(t, packed_shape, '* i j')
+
+     if should_transpose:
+         t = t.transpose(-1, -2)
+
+     return t
+
+ class Trainer(Module):
+     def __init__(
+         self,
+         model: TinyRecursiveModel | Module,
+         dataset: Dataset,
+         optim_klass = AdamW,
+         learning_rate = 1e-4,
+         weight_decay = 1.,
+         batch_size = 16,
+         epochs = 2,
+         halt_prob_thres = 0.5,
+         max_recurrent_steps = 12,
+         ema_decay_rate = 0.999,
+         switch_ema_every = 10000, # switch ema https://arxiv.org/abs/2402.09240
+         accelerate_kwargs: dict = dict(),
+         cpu = False
+     ):
+         super().__init__()
+
+         self.accelerator = Accelerator(**accelerate_kwargs, cpu = cpu)
+
+         self.batch_size = batch_size
+         self.epochs = epochs
+
+         self.dataset = dataset
+         self.dataloader = dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)
+
+         self.optim = optim_klass(
+             model.parameters(),
+             lr = learning_rate,
+             weight_decay = weight_decay
+         )
+
+         self.model = model
+
+         # ema model
+
+         self.ema_model = None
+
+         if self.accelerator.is_main_process:
+             self.ema_model = EMA(
+                 model,
+                 beta = ema_decay_rate,
+                 update_model_with_ema_every = switch_ema_every,
+                 forward_method_names = ('predict',)
+             )
+
+         # recurrent and act related variables
+
+         self.halt_prob_thres = halt_prob_thres
+
+         self.max_recurrent_steps = max_recurrent_steps
+
+         # prepare maybe distributed
+
+         self.model, self.optim, self.dataloader = self.accelerator.prepare(self.model, self.optim, self.dataloader)
+
+     def forward(self):
+
+         for epoch in range_from_one(self.epochs):
+
+             for dataset_input, dataset_output in self.dataloader:
+
+                 outputs, latents = self.model.get_initial()
+
+                 for recurrent_step in range_from_one(self.max_recurrent_steps):
+
+                     loss, (main_loss, halt_loss), outputs, latents, pred, halt = self.model(dataset_input, outputs, latents, labels = dataset_output)
+
+                     self.accelerator.print(f'[{epoch} ({recurrent_step} / {self.max_recurrent_steps})] loss: {main_loss.mean().item():.3f} | halt loss: {halt_loss.mean().item():.3f}')
+
+                     self.accelerator.backward(loss)
+
+                     self.optim.step()
+                     self.optim.zero_grad()
+
+                     if self.accelerator.is_main_process:
+                         self.ema_model.update()
+
+                     # handle halting
+
+                     halt_mask = halt >= self.halt_prob_thres
+
+                     if not halt_mask.any():
+                         continue
+
+                     outputs = outputs[~halt_mask]
+                     latents = latents[~halt_mask]
+                     dataset_input = dataset_input[~halt_mask]
+                     dataset_output = dataset_output[~halt_mask]
+
+                     if is_empty(outputs):
+                         break
+
+         self.accelerator.print('complete')
+
+         if self.accelerator.is_main_process:
+             self.ema_model.copy_params_from_ema_to_model()
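
For orientation, a minimal usage sketch of the new Trainer follows. The TinyRecursiveModel constructor arguments (num_tokens, dim, depth) are illustrative assumptions and do not appear in this diff; the Trainer keyword arguments and the (input, label) pairs yielded by the Dataset follow the trainer.py code above.

    import torch
    from torch.utils.data import Dataset

    from tiny_recursive_model import TinyRecursiveModel, Trainer

    class ToyCopyDataset(Dataset):
        # yields (input, label) pairs, matching how Trainer.forward unpacks
        # each batch into (dataset_input, dataset_output)
        def __init__(self, num_samples = 256, seq_len = 16, num_tokens = 10):
            self.seqs = torch.randint(0, num_tokens, (num_samples, seq_len))

        def __len__(self):
            return self.seqs.shape[0]

        def __getitem__(self, idx):
            seq = self.seqs[idx]
            return seq, seq  # identity target, purely for illustration

    model = TinyRecursiveModel(  # hypothetical constructor arguments, not from this diff
        num_tokens = 10,
        dim = 64,
        depth = 2
    )

    trainer = Trainer(
        model,
        ToyCopyDataset(),
        batch_size = 16,
        epochs = 2,
        max_recurrent_steps = 12,
        cpu = True  # keeps the Accelerator on CPU for this sketch
    )

    trainer()  # Trainer is a Module; calling it runs the training loop in forward()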

tiny_recursive_model/trm.py

@@ -2,13 +2,11 @@ from __future__ import annotations
  from contextlib import nullcontext

  import torch
- from torch import nn
+ from torch import nn, cat, arange, tensor
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList
- from torch.optim import AdamW
- from torch.utils.data import Dataset, DataLoader

- from einops import rearrange, repeat
+ from einops import rearrange, repeat, reduce, pack, unpack
  from einops.layers.torch import Reduce, Rearrange

  # network related
@@ -16,10 +14,6 @@ from einops.layers.torch import Reduce, Rearrange
  from x_transformers import Encoder
  from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D

- # ema - apparently greatly helped with results
-
- from ema_pytorch import EMA
-
  # helpers

  def exists(v):
@@ -28,9 +22,6 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

- def range_from_one(n):
-     return range(1, n + 1)
-
  def is_empty(t):
      return t.numel() == 0

@@ -72,6 +63,10 @@ class TinyRecursiveModel(Module):

          self.halt_loss_weight = halt_loss_weight

+     @property
+     def device(self):
+         return next(self.parameters()).device
+
      def refine_latent_then_output_once(
          self,
          inputs, # (b n d)
@@ -115,118 +110,112 @@

          return outputs, latents

-     def forward(
+     @torch.no_grad()
+     def predict(
          self,
          seq,
-         outputs,
-         latents,
-         labels = None
+         halt_prob_thres = 0.5,
+         num_deep_refinement_steps = 12
      ):
+         batch = seq.shape[0]
+
          inputs = self.input_embed(seq)

-         outputs, latents = self.deep_refinement(inputs, outputs, latents)
+         outputs, latents = self.get_initial()

-         pred = self.to_pred(outputs)
+         # active batch indices, the step it exited at, and the final output predictions

-         should_halt = self.to_halt_pred(outputs)
+         active_batch_indices = arange(batch, device = self.device, dtype = torch.float32)

-         outputs, latents = outputs.detach(), latents.detach()
+         preds = []
+         exited_step_indices = []
+         exited_batch_indices = []

-         return_package = (outputs, latents, pred, should_halt)
+         for i in range(num_deep_refinement_steps):
+             step = i + 1
+             is_last = step == num_deep_refinement_steps

-         if not exists(labels):
-             return return_package
+             outputs, latents = self.deep_refinement(inputs, outputs, latents)

-         # calculate loss if labels passed in
+             halt_prob = self.to_halt_pred(outputs)

-         loss = F.cross_entropy(rearrange(pred, 'b n l -> b l n'), labels)
+             should_halt = (halt_prob >= halt_prob_thres) | is_last

-         is_all_correct = (pred.argmax(dim = -1) == labels).all(dim = -1)
+             if not should_halt.any():
+                 continue

-         halt_loss = F.binary_cross_entropy(should_halt, is_all_correct.float())
+             # append to exited predictions

-         # total loss and loss breakdown
+             pred = self.to_pred(outputs[should_halt])
+             preds.append(pred)

-         total_loss = loss + halt_loss * self.halt_loss_weight
-         losses = (loss, halt_loss)
+             # append the step at which early halted

-         return (total_loss, losses, *return_package)
+             exited_step_indices.extend([step] * should_halt.sum().item())

- # trainer
+             # append indices for sorting back

- class Trainer(Module):
-     def __init__(
-         self,
-         model: TinyRecursiveModel | Module,
-         dataset: Dataset,
-         optim_klass = AdamW,
-         learning_rate = 1e-4,
-         weight_decay = 1.,
-         batch_size = 16,
-         epochs = 2,
-         halt_prob_thres = 0.5,
-         max_recurrent_steps = 12,
-         ema_decay_rate = 0.999,
-         ema_update_model_with_ema_every = 10000
-     ):
-         super().__init__()
+             exited_batch_indices.append(active_batch_indices[should_halt])

-         self.batch_size = batch_size
-         self.epochs = epochs
+             if is_last:
+                 continue

-         self.dataset = dataset
-         self.dataloader = dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)
+             # ready for next round

-         self.optim = optim_klass(
-             model.parameters(),
-             lr = learning_rate,
-             weight_decay = weight_decay
-         )
+             inputs = inputs[~should_halt]
+             outputs = outputs[~should_halt]
+             latents = latents[~should_halt]
+             active_batch_indices = active_batch_indices[~should_halt]

-         self.model = model
+             if is_empty(outputs):
+                 break

-         self.ema_model = EMA(
-             model,
-             beta = ema_decay_rate,
-             update_model_with_ema_every = ema_update_model_with_ema_every
-         )
+         preds = cat(preds).argmax(dim = -1)
+         exited_step_indices = tensor(exited_step_indices)

-         self.halt_prob_thres = halt_prob_thres
+         exited_batch_indices = cat(exited_batch_indices)
+         sort_indices = exited_batch_indices.argsort(dim = -1)

-         self.max_recurrent_steps = max_recurrent_steps
+         return preds[sort_indices], exited_step_indices[sort_indices]

-     def forward(self):
+     def forward(
+         self,
+         seq,
+         outputs,
+         latents,
+         labels = None
+     ):
+         inputs = self.input_embed(seq)

-         for epoch in range_from_one(self.epochs):
+         outputs, latents = self.deep_refinement(inputs, outputs, latents)

-             for dataset_input, dataset_output in self.dataloader:
+         pred = self.to_pred(outputs)

-                 outputs, latents = self.model.get_initial()
+         halt_prob = self.to_halt_pred(outputs)

-                 for recurrent_step in range_from_one(self.max_recurrent_steps):
+         outputs, latents = outputs.detach(), latents.detach()

-                     loss, (main_loss, halt_loss), outputs, latents, pred, halt = self.model(dataset_input, outputs, latents, labels = dataset_output)
+         return_package = (outputs, latents, pred, halt_prob)

-                     print(f'[{epoch} ({recurrent_step} / {self.max_recurrent_steps})] loss: {main_loss.item():.3f} | halt loss: {halt_loss.item():.3f}')
+         if not exists(labels):
+             return return_package

-                     loss.backward()
+         # calculate loss if labels passed in

-                     self.optim.step()
-                     self.optim.zero_grad()
+         loss = F.cross_entropy(rearrange(pred, 'b n l -> b l n'), labels, reduction = 'none')
+         loss = reduce(loss, 'b ... -> b', 'mean')

-                     self.ema_model.update()
+         is_all_correct = (pred.argmax(dim = -1) == labels).all(dim = -1)

-                     # handle halting
+         halt_loss = F.binary_cross_entropy(halt_prob, is_all_correct.float(), reduction = 'none')

-                     halt_mask = halt >= self.halt_prob_thres
+         # total loss and loss breakdown

-                     if not halt_mask.any():
-                         continue
+         total_loss = (
+             loss +
+             halt_loss * self.halt_loss_weight
+         )

-                     outputs = outputs[~halt_mask]
-                     latents = latents[~halt_mask]
-                     dataset_input = dataset_input[~halt_mask]
-                     dataset_output = dataset_output[~halt_mask]
+         losses = (loss, halt_loss)

-                     if is_empty(outputs):
-                         break
+         return (total_loss.mean(), losses, *return_package)
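
The new predict method adds a dedicated inference path (the old forward simply returned raw outputs when no labels were passed): it runs up to num_deep_refinement_steps rounds of deep refinement, retires each sequence once its halt probability reaches halt_prob_thres, and returns argmax predictions together with the refinement step at which each sample exited, sorted back into the original batch order. A minimal sketch, assuming model is a trained TinyRecursiveModel and seq is a batch of token ids:

    preds, exit_steps = model.predict(
        seq,
        halt_prob_thres = 0.5,
        num_deep_refinement_steps = 12
    )
    # preds: per-position argmax token predictions, in original batch order
    # exit_steps: the refinement step at which each sample halted

Because the Trainer constructs its EMA wrapper with forward_method_names = ('predict',), the same call is presumably also exposed on the EMA copy, e.g. trainer.ema_model.predict(seq); that forwarding behaviour is an assumption about ema_pytorch, not something shown in this diff.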

tiny_recursive_model-0.0.2.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tiny-recursive-model
- Version: 0.0.1
+ Version: 0.0.2
  Summary: Tiny Recursive Model
  Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
  Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model

tiny_recursive_model-0.0.2.dist-info/RECORD (new file)

@@ -0,0 +1,8 @@
+ tiny_recursive_model/__init__.py,sha256=obuHzL-k9cpbJiwFxopEYBuMQi898C0r45hqgB6x5Yo,123
+ tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
+ tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
+ tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
+ tiny_recursive_model-0.0.2.dist-info/METADATA,sha256=JQeCdRnntKCNIMm5LVmSanMV7pavyPjgWyL3fm2LDJ0,3107
+ tiny_recursive_model-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ tiny_recursive_model-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ tiny_recursive_model-0.0.2.dist-info/RECORD,,

tiny_recursive_model-0.0.1.dist-info/RECORD (removed)

@@ -1,7 +0,0 @@
- tiny_recursive_model/__init__.py,sha256=UufV6--ilPn4quRWyhvaFRMKRfHvfLsAmF9RU-L31rM,77
- tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
- tiny_recursive_model/trm.py,sha256=YwzTod4CeeXlbAiM-TBB7rEEHWsxnPxavaGiVCTPMEM,6350
- tiny_recursive_model-0.0.1.dist-info/METADATA,sha256=G-cM7okuLAiOxhofXoRh2Ih-bwYifcA3AAhmYmKo-v4,3107
- tiny_recursive_model-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- tiny_recursive_model-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- tiny_recursive_model-0.0.1.dist-info/RECORD,,