tiny-recursive-model 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,9 @@ def default(v, d):
25
25
  def is_empty(t):
26
26
  return t.numel() == 0
27
27
 
28
+ def range_from_one(n):
29
+ return range(1, n + 1)
30
+
28
31
  # classes
29
32
 
30
33
  class TinyRecursiveModel(Module):
@@ -98,11 +101,11 @@ class TinyRecursiveModel(Module):
98
101
  latents, # (b n d)
99
102
  ):
100
103
 
101
- for i in range(self.num_refinement_blocks):
104
+ for step in range_from_one(self.num_refinement_blocks):
102
105
 
103
106
  # only last round of refinement receives gradients
104
107
 
105
- is_last = i == (self.num_refinement_blocks - 1)
108
+ is_last = step == self.num_refinement_blocks
106
109
  context = torch.no_grad if not is_last else nullcontext
107
110
 
108
111
  with context():
@@ -115,7 +118,7 @@ class TinyRecursiveModel(Module):
115
118
  self,
116
119
  seq,
117
120
  halt_prob_thres = 0.5,
118
- num_deep_refinement_steps = 12
121
+ max_deep_refinement_steps = 12
119
122
  ):
120
123
  batch = seq.shape[0]
121
124
 
@@ -131,9 +134,8 @@ class TinyRecursiveModel(Module):
131
134
  exited_step_indices = []
132
135
  exited_batch_indices = []
133
136
 
134
- for i in range(num_deep_refinement_steps):
135
- step = i + 1
136
- is_last = step == num_deep_refinement_steps
137
+ for step in range_from_one(max_deep_refinement_steps):
138
+ is_last = step == max_deep_refinement_steps
137
139
 
138
140
  outputs, latents = self.deep_refinement(inputs, outputs, latents)
139
141
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tiny-recursive-model
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: Tiny Recursive Model
5
5
  Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
6
6
  Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
@@ -77,6 +77,8 @@ trm = TinyRecursiveModel(
77
77
  ),
78
78
  )
79
79
 
80
+ # mock dataset
81
+
80
82
  from torch.utils.data import Dataset
81
83
  class MockDataset(Dataset):
82
84
  def __len__(self):
@@ -87,9 +89,13 @@ class MockDataset(Dataset):
87
89
  out = torch.randint(0, 256, (256,))
88
90
  return inp, out
89
91
 
92
+ mock_dataset = MockDataset()
93
+
94
+ # trainer
95
+
90
96
  trainer = Trainer(
91
97
  trm,
92
- MockDataset(),
98
+ mock_dataset,
93
99
  epochs = 1,
94
100
  batch_size = 16,
95
101
  cpu = True
@@ -97,9 +103,18 @@ trainer = Trainer(
97
103
 
98
104
  trainer()
99
105
 
100
- pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)), halt_prob_thres = 0.1)
106
+ # inference
107
+
108
+ pred_answer, exit_indices = trm.predict(
109
+ torch.randint(0, 256, (1, 256)),
110
+ max_deep_refinement_steps = 12,
111
+ halt_prob_thres = 0.1
112
+ )
113
+
114
+ # save to collection of specialized networks for tool call
101
115
 
102
116
  torch.save(trm.state_dict(), 'saved-trm.pt')
117
+
103
118
  ```
104
119
 
105
120
  ## Citations
@@ -0,0 +1,8 @@
1
+ tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
2
+ tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
+ tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
+ tiny_recursive_model/trm.py,sha256=pvhXZjDFoLQR-bc8ZDc_ikklF-s21cY5o6xYAwSDoL8,6048
5
+ tiny_recursive_model-0.0.4.dist-info/METADATA,sha256=dZKb8mPxFPRghjSG7ZCofEI80nSDlAjFVQAB8JvUAi0,4119
6
+ tiny_recursive_model-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ tiny_recursive_model-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ tiny_recursive_model-0.0.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
2
- tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
- tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
- tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
5
- tiny_recursive_model-0.0.3.dist-info/METADATA,sha256=0enBPVOxRoReOf0hms_ZoAI4HHdMWUrrW4Ps0MTuQ9g,3943
6
- tiny_recursive_model-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- tiny_recursive_model-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- tiny_recursive_model-0.0.3.dist-info/RECORD,,