tiny-recursive-model 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,3 +5,7 @@ from tiny_recursive_model.trm import (
5
5
  from tiny_recursive_model.trainer import (
6
6
  Trainer
7
7
  )
8
+
9
+ from tiny_recursive_model.mlp_mixer_1d import (
10
+ MLPMixer1D
11
+ )
@@ -25,6 +25,9 @@ def default(v, d):
25
25
  def is_empty(t):
26
26
  return t.numel() == 0
27
27
 
28
+ def range_from_one(n):
29
+ return range(1, n + 1)
30
+
28
31
  # classes
29
32
 
30
33
  class TinyRecursiveModel(Module):
@@ -98,11 +101,11 @@ class TinyRecursiveModel(Module):
98
101
  latents, # (b n d)
99
102
  ):
100
103
 
101
- for i in range(self.num_refinement_blocks):
104
+ for step in range_from_one(self.num_refinement_blocks):
102
105
 
103
106
  # only last round of refinement receives gradients
104
107
 
105
- is_last = i == (self.num_refinement_blocks - 1)
108
+ is_last = step == self.num_refinement_blocks
106
109
  context = torch.no_grad if not is_last else nullcontext
107
110
 
108
111
  with context():
@@ -115,7 +118,7 @@ class TinyRecursiveModel(Module):
115
118
  self,
116
119
  seq,
117
120
  halt_prob_thres = 0.5,
118
- num_deep_refinement_steps = 12
121
+ max_deep_refinement_steps = 12
119
122
  ):
120
123
  batch = seq.shape[0]
121
124
 
@@ -131,9 +134,8 @@ class TinyRecursiveModel(Module):
131
134
  exited_step_indices = []
132
135
  exited_batch_indices = []
133
136
 
134
- for i in range(num_deep_refinement_steps):
135
- step = i + 1
136
- is_last = step == num_deep_refinement_steps
137
+ for step in range_from_one(max_deep_refinement_steps):
138
+ is_last = step == max_deep_refinement_steps
137
139
 
138
140
  outputs, latents = self.deep_refinement(inputs, outputs, latents)
139
141
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tiny-recursive-model
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: Tiny Recursive Model
5
5
  Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
6
6
  Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
@@ -55,6 +55,68 @@ Official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursi
55
55
 
56
56
  <img width="300" alt="trm-fig3" src="https://github.com/user-attachments/assets/bfe3dd2a-e859-492a-84d5-faf37339f534" />
57
57
 
58
+ ## Install
59
+
60
+ ```bash
61
+ $ pip install tiny-recursive-model
62
+ ```
63
+
64
+ ## Usage
65
+
66
+ ```python
67
+ import torch
68
+ from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer
69
+
70
+ trm = TinyRecursiveModel(
71
+ dim = 16,
72
+ num_tokens = 256,
73
+ network = MLPMixer1D(
74
+ dim = 16,
75
+ depth = 2,
76
+ seq_len = 256
77
+ ),
78
+ )
79
+
80
+ # mock dataset
81
+
82
+ from torch.utils.data import Dataset
83
+ class MockDataset(Dataset):
84
+ def __len__(self):
85
+ return 16
86
+
87
+ def __getitem__(self, idx):
88
+ inp = torch.randint(0, 256, (256,))
89
+ out = torch.randint(0, 256, (256,))
90
+ return inp, out
91
+
92
+ mock_dataset = MockDataset()
93
+
94
+ # trainer
95
+
96
+ trainer = Trainer(
97
+ trm,
98
+ mock_dataset,
99
+ epochs = 1,
100
+ batch_size = 16,
101
+ cpu = True
102
+ )
103
+
104
+ trainer()
105
+
106
+ # inference
107
+
108
+ pred_answer, exit_indices = trm.predict(
109
+ torch.randint(0, 256, (1, 256)),
110
+ max_deep_refinement_steps = 12,
111
+ halt_prob_thres = 0.1
112
+ )
113
+
114
+ # save to collection of specialized networks for tool call
115
+
116
+ torch.save(trm.state_dict(), 'saved-trm.pt')
117
+
118
+ ```
119
+
58
120
  ## Citations
59
121
 
60
122
  ```bibtex
@@ -0,0 +1,8 @@
1
+ tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
2
+ tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
+ tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
+ tiny_recursive_model/trm.py,sha256=pvhXZjDFoLQR-bc8ZDc_ikklF-s21cY5o6xYAwSDoL8,6048
5
+ tiny_recursive_model-0.0.4.dist-info/METADATA,sha256=dZKb8mPxFPRghjSG7ZCofEI80nSDlAjFVQAB8JvUAi0,4119
6
+ tiny_recursive_model-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ tiny_recursive_model-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ tiny_recursive_model-0.0.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- tiny_recursive_model/__init__.py,sha256=obuHzL-k9cpbJiwFxopEYBuMQi898C0r45hqgB6x5Yo,123
2
- tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
- tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
- tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
5
- tiny_recursive_model-0.0.2.dist-info/METADATA,sha256=JQeCdRnntKCNIMm5LVmSanMV7pavyPjgWyL3fm2LDJ0,3107
6
- tiny_recursive_model-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- tiny_recursive_model-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- tiny_recursive_model-0.0.2.dist-info/RECORD,,