tiny-recursive-model 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,3 +5,7 @@ from tiny_recursive_model.trm import (
5
5
  from tiny_recursive_model.trainer import (
6
6
  Trainer
7
7
  )
8
+
9
+ from tiny_recursive_model.mlp_mixer_1d import (
10
+ MLPMixer1D
11
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tiny-recursive-model
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Tiny Recursive Model
5
5
  Project-URL: Homepage, https://pypi.org/project/tiny-recursive-model/
6
6
  Project-URL: Repository, https://github.com/lucidrains/tiny-recursive-model
@@ -55,6 +55,53 @@ Official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursi
55
55
 
56
56
  <img width="300" alt="trm-fig3" src="https://github.com/user-attachments/assets/bfe3dd2a-e859-492a-84d5-faf37339f534" />
57
57
 
58
+ ## Install
59
+
60
+ ```bash
61
+ $ pip install tiny-recursive-model
62
+ ```
63
+
64
+ ## Usage
65
+
66
+ ```python
67
+ import torch
68
+ from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer
69
+
70
+ trm = TinyRecursiveModel(
71
+ dim = 16,
72
+ num_tokens = 256,
73
+ network = MLPMixer1D(
74
+ dim = 16,
75
+ depth = 2,
76
+ seq_len = 256
77
+ ),
78
+ )
79
+
80
+ from torch.utils.data import Dataset
81
+ class MockDataset(Dataset):
82
+ def __len__(self):
83
+ return 16
84
+
85
+ def __getitem__(self, idx):
86
+ inp = torch.randint(0, 256, (256,))
87
+ out = torch.randint(0, 256, (256,))
88
+ return inp, out
89
+
90
+ trainer = Trainer(
91
+ trm,
92
+ MockDataset(),
93
+ epochs = 1,
94
+ batch_size = 16,
95
+ cpu = True
96
+ )
97
+
98
+ trainer()
99
+
100
+ pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)), halt_prob_thres = 0.1)
101
+
102
+ torch.save(trm.state_dict(), 'saved-trm.pt')
103
+ ```
104
+
58
105
  ## Citations
59
106
 
60
107
  ```bibtex
@@ -0,0 +1,8 @@
1
+ tiny_recursive_model/__init__.py,sha256=zuMcrofGu7DnvJM2Mb-O3tqBJF5q8L-8X8OTmq7_o5w,189
2
+ tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
+ tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
+ tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
5
+ tiny_recursive_model-0.0.3.dist-info/METADATA,sha256=0enBPVOxRoReOf0hms_ZoAI4HHdMWUrrW4Ps0MTuQ9g,3943
6
+ tiny_recursive_model-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ tiny_recursive_model-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ tiny_recursive_model-0.0.3.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- tiny_recursive_model/__init__.py,sha256=obuHzL-k9cpbJiwFxopEYBuMQi898C0r45hqgB6x5Yo,123
2
- tiny_recursive_model/mlp_mixer_1d.py,sha256=6ivDK9dgHdVl1axg2ayifJ7H5QI3hXptHnb6lfNrno0,1398
3
- tiny_recursive_model/trainer.py,sha256=6dQPmRaQZWI6527OvlOdgHKCFsufkZnjSHClRdHjs20,4218
4
- tiny_recursive_model/trm.py,sha256=Ep18uwvhWjHxGeyv42ruXLVc2F6TlZg2_CmeVVfYz7c,6001
5
- tiny_recursive_model-0.0.2.dist-info/METADATA,sha256=JQeCdRnntKCNIMm5LVmSanMV7pavyPjgWyL3fm2LDJ0,3107
6
- tiny_recursive_model-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- tiny_recursive_model-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- tiny_recursive_model-0.0.2.dist-info/RECORD,,