tiny-recursive-model 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tiny_recursive_model/trm.py +8 -6
- {tiny_recursive_model-0.0.3.dist-info → tiny_recursive_model-0.0.5.dist-info}/METADATA +19 -4
- tiny_recursive_model-0.0.5.dist-info/RECORD +8 -0
- tiny_recursive_model-0.0.3.dist-info/RECORD +0 -8
- {tiny_recursive_model-0.0.3.dist-info → tiny_recursive_model-0.0.5.dist-info}/WHEEL +0 -0
- {tiny_recursive_model-0.0.3.dist-info → tiny_recursive_model-0.0.5.dist-info}/licenses/LICENSE +0 -0
tiny_recursive_model/trm.py
CHANGED
@@ -25,6 +25,9 @@ def default(v, d):
|
|
25
25
|
def is_empty(t):
|
26
26
|
return t.numel() == 0
|
27
27
|
|
28
|
+
def range_from_one(n):
|
29
|
+
return range(1, n + 1)
|
30
|
+
|
28
31
|
# classes
|
29
32
|
|
30
33
|
class TinyRecursiveModel(Module):
|
@@ -98,11 +101,11 @@ class TinyRecursiveModel(Module):
|
|
98
101
|
latents, # (b n d)
|
99
102
|
):
|
100
103
|
|
101
|
-
for
|
104
|
+
for step in range_from_one(self.num_refinement_blocks):
|
102
105
|
|
103
106
|
# only last round of refinement receives gradients
|
104
107
|
|
105
|
-
is_last =
|
108
|
+
is_last = step == self.num_refinement_blocks
|
106
109
|
context = torch.no_grad if not is_last else nullcontext
|
107
110
|
|
108
111
|
with context():
|
@@ -115,7 +118,7 @@ class TinyRecursiveModel(Module):
|
|
115
118
|
self,
|
116
119
|
seq,
|
117
120
|
halt_prob_thres = 0.5,
|
118
|
-
|
121
|
+
max_deep_refinement_steps = 12
|
119
122
|
):
|
120
123
|
batch = seq.shape[0]
|
121
124
|
|
@@ -131,9 +134,8 @@ class TinyRecursiveModel(Module):
|
|
131
134
|
exited_step_indices = []
|
132
135
|
exited_batch_indices = []
|
133
136
|
|
134
|
-
for
|
135
|
-
|
136
|
-
is_last = step == num_deep_refinement_steps
|
137
|
+
for step in range_from_one(max_deep_refinement_steps):
|
138
|
+
is_last = step == max_deep_refinement_steps
|
137
139
|
|
138
140
|
outputs, latents = self.deep_refinement(inputs, outputs, latents)
|
139
141
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: tiny-recursive-model
|
3
|
-
Version: 0.0.3
|
3
|
+
Version: 0.0.5
|
4
4
|
Summary: Tiny Recursive Model
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
|
@@ -38,7 +38,7 @@ Requires-Dist: accelerate
|
|
38
38
|
Requires-Dist: einops>=0.8.1
|
39
39
|
Requires-Dist: ema-pytorch
|
40
40
|
Requires-Dist: torch>=2.4
|
41
|
-
Requires-Dist: x-transformers
|
41
|
+
Requires-Dist: x-transformers>=2.8.4
|
42
42
|
Provides-Extra: examples
|
43
43
|
Provides-Extra: test
|
44
44
|
Requires-Dist: pytest; extra == 'test'
|
@@ -77,6 +77,8 @@ trm = TinyRecursiveModel(
|
|
77
77
|
),
|
78
78
|
)
|
79
79
|
|
80
|
+
# mock dataset
|
81
|
+
|
80
82
|
from torch.utils.data import Dataset
|
81
83
|
class MockDataset(Dataset):
|
82
84
|
def __len__(self):
|
@@ -87,9 +89,13 @@ class MockDataset(Dataset):
|
|
87
89
|
out = torch.randint(0, 256, (256,))
|
88
90
|
return inp, out
|
89
91
|
|
92
|
+
mock_dataset = MockDataset()
|
93
|
+
|
94
|
+
# trainer
|
95
|
+
|
90
96
|
trainer = Trainer(
|
91
97
|
trm,
|
92
|
-
|
98
|
+
mock_dataset,
|
93
99
|
epochs = 1,
|
94
100
|
batch_size = 16,
|
95
101
|
cpu = True
|
@@ -97,9 +103,18 @@ trainer = Trainer(
|
|
97
103
|
|
98
104
|
trainer()
|
99
105
|
|
100
|
-
|
106
|
+
# inference
|
107
|
+
|
108
|
+
pred_answer, exit_indices = trm.predict(
|
109
|
+
torch.randint(0, 256, (1, 256)),
|
110
|
+
max_deep_refinement_steps = 12,
|
111
|
+
halt_prob_thres = 0.1
|
112
|
+
)
|
113
|
+
|
114
|
+
# save to collection of specialized networks for tool call
|
101
115
|
|
102
116
|
torch.save(trm.state_dict(), 'saved-trm.pt')
|
117
|
+
|
103
118
|
```
|
104
119
|
|
105
120
|
## Citations
|
@@ -0,0 +1,8 @@
|
|
1
|
+
tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
|
2
|
+
tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
|
3
|
+
tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
|
4
|
+
tiny_recursive_model/trm.py,sha256=pvhXZjDFoLQR-bc8ZDc_ikklF-s21cY5o6xYAwSDoL8,6048
|
5
|
+
tiny_recursive_model-0.0.5.dist-info/METADATA,sha256=GHJavWHHc4O001dFHLsDimAVIiiogCIv4al4QFWix60,4126
|
6
|
+
tiny_recursive_model-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
tiny_recursive_model-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
tiny_recursive_model-0.0.5.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
|
2
|
-
tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
|
3
|
-
tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
|
4
|
-
tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
|
5
|
-
tiny_recursive_model-0.0.3.dist-info/METADATA,sha256=0enBPVOxRoReOf0hms_ZoAI4HHdMWUrrW4Ps0MTuQ9g,3943
|
6
|
-
tiny_recursive_model-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
tiny_recursive_model-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
tiny_recursive_model-0.0.3.dist-info/RECORD,,
|
File without changes
|
{tiny_recursive_model-0.0.3.dist-info → tiny_recursive_model-0.0.5.dist-info}/licenses/LICENSE
RENAMED
File without changes
|