tiny-recursive-model 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tiny_recursive_model/__init__.py +4 -0
- tiny_recursive_model/trm.py +8 -6
- {tiny_recursive_model-0.0.2.dist-info → tiny_recursive_model-0.0.4.dist-info}/METADATA +63 -1
- tiny_recursive_model-0.0.4.dist-info/RECORD +8 -0
- tiny_recursive_model-0.0.2.dist-info/RECORD +0 -8
- {tiny_recursive_model-0.0.2.dist-info → tiny_recursive_model-0.0.4.dist-info}/WHEEL +0 -0
- {tiny_recursive_model-0.0.2.dist-info → tiny_recursive_model-0.0.4.dist-info}/licenses/LICENSE +0 -0
tiny_recursive_model/__init__.py
CHANGED
tiny_recursive_model/trm.py
CHANGED
@@ -25,6 +25,9 @@ def default(v, d):
|
|
25
25
|
def is_empty(t):
|
26
26
|
return t.numel() == 0
|
27
27
|
|
28
|
+
def range_from_one(n):
|
29
|
+
return range(1, n + 1)
|
30
|
+
|
28
31
|
# classes
|
29
32
|
|
30
33
|
class TinyRecursiveModel(Module):
|
@@ -98,11 +101,11 @@ class TinyRecursiveModel(Module):
|
|
98
101
|
latents, # (b n d)
|
99
102
|
):
|
100
103
|
|
101
|
-
for
|
104
|
+
for step in range_from_one(self.num_refinement_blocks):
|
102
105
|
|
103
106
|
# only last round of refinement receives gradients
|
104
107
|
|
105
|
-
is_last =
|
108
|
+
is_last = step == self.num_refinement_blocks
|
106
109
|
context = torch.no_grad if not is_last else nullcontext
|
107
110
|
|
108
111
|
with context():
|
@@ -115,7 +118,7 @@ class TinyRecursiveModel(Module):
|
|
115
118
|
self,
|
116
119
|
seq,
|
117
120
|
halt_prob_thres = 0.5,
|
118
|
-
|
121
|
+
max_deep_refinement_steps = 12
|
119
122
|
):
|
120
123
|
batch = seq.shape[0]
|
121
124
|
|
@@ -131,9 +134,8 @@ class TinyRecursiveModel(Module):
|
|
131
134
|
exited_step_indices = []
|
132
135
|
exited_batch_indices = []
|
133
136
|
|
134
|
-
for
|
135
|
-
|
136
|
-
is_last = step == num_deep_refinement_steps
|
137
|
+
for step in range_from_one(max_deep_refinement_steps):
|
138
|
+
is_last = step == max_deep_refinement_steps
|
137
139
|
|
138
140
|
outputs, latents = self.deep_refinement(inputs, outputs, latents)
|
139
141
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: tiny-recursive-model
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.4
|
4
4
|
Summary: Tiny Recursive Model
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
|
@@ -55,6 +55,68 @@ Official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursi
|
|
55
55
|
|
56
56
|
<img width="300" alt="trm-fig3" src="https://github.com/user-attachments/assets/bfe3dd2a-e859-492a-84d5-faf37339f534" />
|
57
57
|
|
58
|
+
## Install
|
59
|
+
|
60
|
+
```bash
|
61
|
+
$ pip install tiny-recursive-model
|
62
|
+
```
|
63
|
+
|
64
|
+
## Usage
|
65
|
+
|
66
|
+
```python
|
67
|
+
import torch
|
68
|
+
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer
|
69
|
+
|
70
|
+
trm = TinyRecursiveModel(
|
71
|
+
dim = 16,
|
72
|
+
num_tokens = 256,
|
73
|
+
network = MLPMixer1D(
|
74
|
+
dim = 16,
|
75
|
+
depth = 2,
|
76
|
+
seq_len = 256
|
77
|
+
),
|
78
|
+
)
|
79
|
+
|
80
|
+
# mock dataset
|
81
|
+
|
82
|
+
from torch.utils.data import Dataset
|
83
|
+
class MockDataset(Dataset):
|
84
|
+
def __len__(self):
|
85
|
+
return 16
|
86
|
+
|
87
|
+
def __getitem__(self, idx):
|
88
|
+
inp = torch.randint(0, 256, (256,))
|
89
|
+
out = torch.randint(0, 256, (256,))
|
90
|
+
return inp, out
|
91
|
+
|
92
|
+
mock_dataset = MockDataset()
|
93
|
+
|
94
|
+
# trainer
|
95
|
+
|
96
|
+
trainer = Trainer(
|
97
|
+
trm,
|
98
|
+
mock_dataset,
|
99
|
+
epochs = 1,
|
100
|
+
batch_size = 16,
|
101
|
+
cpu = True
|
102
|
+
)
|
103
|
+
|
104
|
+
trainer()
|
105
|
+
|
106
|
+
# inference
|
107
|
+
|
108
|
+
pred_answer, exit_indices = trm.predict(
|
109
|
+
torch.randint(0, 256, (1, 256)),
|
110
|
+
max_deep_refinement_steps = 12,
|
111
|
+
halt_prob_thres = 0.1
|
112
|
+
)
|
113
|
+
|
114
|
+
# save to collection of specialized networks for tool call
|
115
|
+
|
116
|
+
torch.save(trm.state_dict(), 'saved-trm.pt')
|
117
|
+
|
118
|
+
```
|
119
|
+
|
58
120
|
## Citations
|
59
121
|
|
60
122
|
```bibtex
|
@@ -0,0 +1,8 @@
|
|
1
|
+
tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
|
2
|
+
tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
|
3
|
+
tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
|
4
|
+
tiny_recursive_model/trm.py,sha256=pvhXZjDFoLQR-bc8ZDc_ikklF-s21cY5o6xYAwSDoL8,6048
|
5
|
+
tiny_recursive_model-0.0.4.dist-info/METADATA,sha256=dZKb8mPxFPRghjSG7ZCofEI80nSDlAjFVQAB8JvUAi0,4119
|
6
|
+
tiny_recursive_model-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
tiny_recursive_model-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
tiny_recursive_model-0.0.4.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
tiny_recursive_model/__init__.py,sha256=obuHzL-k9cpbJiwFxopEYBuMQi898C0r45hqgB6x5Yo,123
|
2
|
-
tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
|
3
|
-
tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
|
4
|
-
tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
|
5
|
-
tiny_recursive_model-0.0.2.dist-info/METADATA,sha256=JQeCdRnntKCNIMm5LVmSanMV7pavyPjgWyL3fm2LDJ0,3107
|
6
|
-
tiny_recursive_model-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
tiny_recursive_model-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
tiny_recursive_model-0.0.2.dist-info/RECORD,,
|
File without changes
|
{tiny_recursive_model-0.0.2.dist-info → tiny_recursive_model-0.0.4.dist-info}/licenses/LICENSE
RENAMED
File without changes
|