locoformer-0.0.6.tar.gz

@@ -0,0 +1,36 @@
+ # This workflow will upload a Python Package using Twine when a release is created
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+ # This workflow uses actions that are not certified by GitHub.
+ # They are provided by a third-party and are governed by
+ # separate terms of service, privacy policy, and support
+ # documentation.
+
+ name: Upload Python Package
+
+ on:
+   release:
+     types: [published]
+
+ jobs:
+   deploy:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python
+       uses: actions/setup-python@v2
+       with:
+         python-version: '3.x'
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install build
+     - name: Build package
+       run: python -m build
+     - name: Publish package
+       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+       with:
+         user: __token__
+         password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,21 @@
+ name: Pytest
+ on: [push, pull_request]
+
+ jobs:
+   build:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v4
+     - name: Set up Python 3.10
+       uses: actions/setup-python@v5
+       with:
+         python-version: "3.10"
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         python -m pip install -e .[test]
+     - name: Test with pytest
+       run: |
+         python -m pytest tests/
@@ -0,0 +1,208 @@
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,71 @@
+ Metadata-Version: 2.4
+ Name: locoformer
+ Version: 0.0.6
+ Summary: LocoFormer
+ Project-URL: Homepage, https://pypi.org/project/locoformer/
+ Project-URL: Repository, https://github.com/lucidrains/locoformer
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,attention mechanism,cross-embodiment,deep learning,robotics,transformer
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: assoc-scan
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: rotary-embedding-torch
+ Requires-Dist: torch>=2.4
+ Requires-Dist: x-mlps-pytorch
+ Provides-Extra: examples
+ Requires-Dist: accelerate; extra == 'examples'
+ Requires-Dist: tqdm; extra == 'examples'
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./fig3.png" width="400px"></img>
+
+ ## LocoFormer (wip)
+
+ [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
+
+ The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferred to the real world, the robot gained the ability to adapt to insults (physical perturbations). The XL memories span multiple trials, which is what allows the robot to learn in-context adaptation.
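+
+ ## Install
+
+ ```bash
+ $ pip install locoformer
+ ```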
+
+ ## Sponsors
+
+ This open-sourced work is sponsored by [Safe Sentinel](https://www.safesentinels.com/).
+
+ ## Citations
+
+ ```bibtex
+ @article{liu2025locoformer,
+     title   = {LocoFormer: Generalist Locomotion via Long-Context Adaptation},
+     author  = {Liu, Min and Pathak, Deepak and Agarwal, Ananye},
+     journal = {Conference on Robot Learning ({CoRL})},
+     year    = {2025}
+ }
+ ```
@@ -0,0 +1,22 @@
+ <img src="./fig3.png" width="400px"></img>
+
+ ## LocoFormer (wip)
+
+ [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
+
+ The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferred to the real world, the robot gained the ability to adapt to insults (physical perturbations). The XL memories span multiple trials, which is what allows the robot to learn in-context adaptation.
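+
+ ## Usage
+
+ A minimal sketch of the current API, adapted from the test suite (the module is wip, so details may shift):
+
+ ```python
+ import torch
+ from torch import nn
+ from locoformer import Locoformer
+
+ model = Locoformer(
+     embedder = nn.Embedding(256, 128),
+     unembedder = nn.Linear(128, 256, bias = False),
+     transformer = dict(
+         dim = 128,
+         depth = 1,
+         window_size = 256
+     )
+ )
+
+ # feed one window at a time, carrying the Transformer-XL kv cache across segments
+
+ seq = torch.randint(0, 256, (3, 256))
+
+ logits, cache = model(seq)
+ logits, cache = model(seq, cache = cache)
+ ```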
+
+ ## Sponsors
+
+ This open-sourced work is sponsored by [Safe Sentinel](https://www.safesentinels.com/).
+
+ ## Citations
+
+ ```bibtex
+ @article{liu2025locoformer,
+     title   = {LocoFormer: Generalist Locomotion via Long-Context Adaptation},
+     author  = {Liu, Min and Pathak, Deepak and Agarwal, Ananye},
+     journal = {Conference on Robot Learning ({CoRL})},
+     year    = {2025}
+ }
+ ```
@@ -0,0 +1,3 @@
+ # Data source
+
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file
Binary file
@@ -0,0 +1 @@
+ from locoformer.locoformer import Locoformer
@@ -0,0 +1,483 @@
+ from __future__ import annotations
+ from functools import partial
+
+ import torch
+ from torch import nn, cat, stack, arange, is_tensor, Tensor
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
+ from torch.utils._pytree import tree_map
+
+ import einx
+ from einops import rearrange, einsum
+ from einops.layers.torch import Rearrange
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from assoc_scan import AssocScan
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def first(arr):
+     return arr[0]
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def tree_map_tensor(x, fn):
+     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
+
+ def detach_all(x):
+     return tree_map_tensor(x, lambda t: t.detach())
+
+ def combine_kv_cache(cache1, cache2):
+     combined_cache = []
+
+     for layer_cache1, layer_cache2 in zip(cache1, cache2):
+         next_cache = cat((layer_cache1, layer_cache2), dim = -2)
+         combined_cache.append(next_cache)
+
+     return combined_cache
+
+ # generalized advantage estimate
+
+ @torch.no_grad()
+ def calc_gae(
+     rewards,
+     values,
+     masks,
+     gamma = 0.99,
+     lam = 0.95,
+     use_accelerated = None
+ ):
+     assert values.shape[-1] == rewards.shape[-1]
+     use_accelerated = default(use_accelerated, rewards.is_cuda)
+
+     values = F.pad(values, (0, 1), value = 0.)
+     values, values_next = values[..., :-1], values[..., 1:]
+
+     delta = rewards + gamma * values_next * masks - values
+     gates = gamma * lam * masks
+
+     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+     gae = scan(gates, delta)
+
+     returns = gae + values
+
+     return returns
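+
+ # usage sketch (hypothetical shapes): rewards / values / non-terminal masks of shape
+ # (batch, time) in, lambda-returns (gae + values) of the same shape out
+ #
+ # rewards = torch.randn(2, 16)
+ # values = torch.randn(2, 16)
+ # masks = torch.ones(2, 16)
+ # returns = calc_gae(rewards, values, masks)   # (2, 16)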
+
+ # transformer-xl mask w/ flex attn
+
+ flex_attention = None
+
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     if torch.cuda.is_available():
+         flex_attention = torch.compile(flex_attention)
+ except ImportError:
+     pass
+
+ def create_xl_mask(
+     seq_len,
+     kv_seq_len,
+     window_size,
+     episode_ids = None,   # (b n) - in the case that within the same batch there are multiple episodes
+     lookback_blocks = 1,  # in transformer-xl, lookback is one window size block, but can be multiple for longer context
+     device = None
+ ):
+     assert kv_seq_len >= seq_len
+     assert window_size <= seq_len
+
+     offset = kv_seq_len - seq_len
+
+     def create_block_mask_fn(b, __, q, k):
+         offset_q = q + offset
+         block_q = offset_q // window_size
+         block_k = k // window_size
+
+         causal_mask = offset_q >= k
+
+         # in transformer-xl, the previous segment is fully attended to - may just double the segments and make this sliding for ease of inference logic
+
+         block_mask = (block_q >= block_k) & (block_q <= (block_k + lookback_blocks))
+
+         mask = causal_mask & block_mask
+
+         # handle intra-episodic attention if needed
+
+         if exists(episode_ids):
+             q_episode = episode_ids[b, offset_q]
+             k_episode = episode_ids[b, k]
+
+             intra_episode_mask = q_episode == k_episode
+             mask = mask & intra_episode_mask
+
+         return mask
+
+     create_kwargs = dict(device = device) if exists(device) else dict()
+     return create_block_mask(create_block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
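+
+ # usage sketch (assumed shapes): a block mask for one 256-token window of queries
+ # attending into the current plus previous segment (512 keys), passed to flex_attention
+ #
+ # block_mask = create_xl_mask(seq_len = 256, kv_seq_len = 512, window_size = 256)
+ # out = flex_attention(q, k, v, block_mask = block_mask)   # q: (b, h, 256, d), k / v: (b, h, 512, d)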
+
+ def create_sliding_mask(
+     seq_len,
+     kv_seq_len,
+     window_size,
+     device = None
+ ):
+     assert kv_seq_len >= seq_len
+     offset = kv_seq_len - seq_len
+
+     def sliding_mask(_, __, q, k):
+         offset_q = q + offset
+         distance = offset_q - k
+
+         backward_sliding_mask = distance <= window_size
+         forward_sliding_mask = distance >= 0
+
+         return backward_sliding_mask & forward_sliding_mask
+
+     create_kwargs = dict(device = device) if exists(device) else dict()
+     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
+
+ # transformer-xl with ppo
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         window_size,
+         dim_head = 64,
+         heads = 8,
+         pre_rmsnorm = True,
+         fixed_window_size = False,
+         accept_value_residual = False
+     ):
+         super().__init__()
+         self.scale = dim_head ** -0.5
+
+         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
+
+         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         self.rotary_embed = RotaryEmbedding(dim_head)
+
+         dim_inner = dim_head * heads
+         self.to_q = LinearNoBias(dim, dim_inner)
+         self.to_kv = LinearNoBias(dim, dim_inner * 2)
+         self.to_out = LinearNoBias(dim_inner, dim)
+
+         self.to_v_gates = Sequential(
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> b h n 1'),
+             nn.Sigmoid()
+         )
+
+         # value residual
+
+         self.accept_value_residual = accept_value_residual
+
+         if accept_value_residual:
+             self.to_value_residual_mix = Sequential(
+                 LinearNoBias(dim, heads),
+                 Rearrange('b n h -> b h n 1'),
+                 nn.Sigmoid()
+             )
+
+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
+     def forward(
+         self,
+         tokens,
+         value_residual = None,
+         kv_cache = None,
+         return_kv_cache = False,
+     ):
+         seq_len = tokens.shape[-2]
+         assert seq_len <= self.window_size
+
+         device = tokens.device
+
+         tokens = self.norm(tokens)
+
+         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
+
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         orig_v = v
+
+         q = q * self.scale
+
+         if exists(value_residual):
+             assert self.accept_value_residual
+             mix = self.to_value_residual_mix(tokens)
+             v = v.lerp(value_residual, mix)
+
+         if exists(kv_cache):
+             ck, cv = kv_cache
+             k = cat((ck, k), dim = -2)
+             v = cat((cv, v), dim = -2)
+
+         if return_kv_cache:
+             next_kv_cache = stack((k, v))
+
+         q, k = self.rotary_embed.rotate_queries_with_cached_keys(q, k)
+
+         sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+
+         i, j = sim.shape[-2:]
+
+         if self.fixed_window_size:
+             i_seq = arange(i, device = device)
+             j_seq = arange(j, device = device) - (j - i)
+             dist = einx.subtract('i, j -> i j', i_seq, j_seq)
+             causal_mask = (dist < 0) | (dist > self.window_size)
+         else:
+             causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+
+         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+         attn = sim.softmax(dim = -1)
+
+         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+         out = out * self.to_v_gates(tokens)
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         if not return_kv_cache:
+             return out
+
+         return out, (next_kv_cache, orig_v)
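+
+ # usage sketch (assumed dims): one window of 32 tokens, then a second window attending
+ # into the first window's kv cache (keys / values run 64 long, queries stay 32)
+ #
+ # attn = Attention(dim = 128, window_size = 32)
+ # out, (cache, values) = attn(torch.randn(1, 32, 128), return_kv_cache = True)
+ # out2, _ = attn(torch.randn(1, 32, 128), kv_cache = cache, return_kv_cache = True)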
+
+ class FeedForward(Module):
+     def __init__(
+         self,
+         dim,
+         expansion_factor = 4.,
+         pre_rmsnorm = True
+     ):
+         super().__init__()
+         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
+
+         dim_inner = int(dim * expansion_factor * 2 / 3)
+
+         self.proj_in = Linear(dim, dim_inner * 2)
+         self.proj_out = Linear(dim_inner, dim)
+
+     def forward(
+         self,
+         x
+     ):
+         x = self.norm(x)
+
+         x, gates = self.proj_in(x).chunk(2, dim = -1)
+
+         x = x * F.gelu(gates)
+
+         return self.proj_out(x)
+
+ class TransformerXL(Module):
+     def __init__(
+         self,
+         dim,
+         depth,
+         window_size,
+         dim_head = 64,
+         heads = 8,
+         expansion_factor = 4.,
+         final_norm = True,
+         fixed_window_size = False,
+     ):
+         super().__init__()
+
+         layers = ModuleList([])
+
+         for i in range(depth):
+             is_first = i == 0
+
+             attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+
+             ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+
+             layers.append(ModuleList([
+                 attn, ff
+             ]))
+
+         self.layers = layers
+         self.norm = RMSNorm(dim) if final_norm else Identity()
+
+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
+     def forward(
+         self,
+         x,
+         cache = None,
+         return_kv_cache = False
+     ):
+
+         cache = default(cache, (None,) * len(self.layers))
+
+         next_kv_caches = []
+         value_residual = None
+
+         for (attn, ff), kv_cache in zip(self.layers, cache):
+
+             attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+
+             x = attn_out + x
+             x = ff(x) + x
+
+             next_kv_caches.append(next_kv_cache)
+             value_residual = default(value_residual, values)
+
+         embed = self.norm(x)
+
+         if not return_kv_cache:
+             return embed
+
+         next_kv_cache = stack(next_kv_caches)
+
+         next_kv_cache = next_kv_cache[..., -self.window_size:, :]
+
+         return embed, next_kv_cache
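+
+ # usage sketch (assumed dims): threading the truncated kv cache across two consecutive
+ # segments, transformer-xl style
+ #
+ # xl = TransformerXL(dim = 128, depth = 2, window_size = 32)
+ # embed, cache = xl(torch.randn(1, 32, 128), return_kv_cache = True)
+ # embed, cache = xl(torch.randn(1, 32, 128), cache = cache, return_kv_cache = True)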
+
+ # class
+
+ class Locoformer(Module):
+     def __init__(
+         self,
+         embedder: Module,
+         unembedder: Module,
+         transformer: dict | TransformerXL,
+         value_network: Module | None = None
+     ):
+         super().__init__()
+
+         if isinstance(transformer, dict):
+             transformer = TransformerXL(**transformer)
+
+         self.transformer = transformer
+
+         self.embedder = embedder
+         self.unembedder = unembedder
+
+         self.value_network = value_network
+
+         self.fixed_window_size = transformer.fixed_window_size
+         self.window_size = transformer.window_size
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def get_stateful_forward(
+         self,
+         initial_states: Tensor | None = None,
+         inference_mode = False,
+         has_batch_dim = False,
+         **kwargs
+     ):
+         window_size = self.window_size
+
+         cache = None
+
+         def stateful_forward(state: Tensor, **override_kwargs):
+             nonlocal cache
+
+             # handle no batch, for easier time rolling out against envs
+
+             if not has_batch_dim:
+                 state = rearrange(state, '... -> 1 ...')
+
+             # forwards
+
+             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+
+             # handle cache
+
+             cache_len = cache.shape[-2]
+
+             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
+                 cache = cache[..., -window_size:, :]
+
+             # maybe remove batch
+
+             if not has_batch_dim:
+                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
+
+             return out
+
+         if inference_mode:
+             stateful_forward = torch.inference_mode()(stateful_forward)
+
+         # handle prompt
+
+         if not exists(initial_states):
+             return stateful_forward
+
+         initial_logits = []
+
+         for state_segments in initial_states.split(self.window_size, dim = -1):
+
+             logits = stateful_forward(state_segments, return_values = False)
+             initial_logits.append(logits)
+
+         initial_logits = cat(initial_logits, dim = -2)
+
+         return stateful_forward, initial_logits
+
+     def forward(
+         self,
+         state: Tensor,
+         cache: Tensor | None = None,
+         detach_cache = False,
+         return_values = False
+     ):
+
+         tokens = self.embedder(state)
+
+         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+
+         # unembed to actions - in language models this would be the next state
+
+         action_logits = self.unembedder(embed)
+
+         out = action_logits
+
+         # maybe detach cache
+
+         if detach_cache:
+             kv_cache = detach_all(kv_cache)
+
+         # handle returning of values
+
+         if return_values:
+             assert exists(self.value_network)
+
+             values = self.value_network(embed)
+
+             if values.ndim == 3:
+                 assert values.shape[-1] == 1
+                 values = rearrange(values, '... 1 -> ...')
+
+             out = (out, values)
+
+         # output and cache
+
+         return out, kv_cache
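+
+ # usage sketch, mirroring the test suite: any embedder / unembedder pair can be wrapped
+ # around the xl core, with the kv cache carried across window-sized segments
+ #
+ # model = Locoformer(
+ #     embedder = nn.Embedding(256, 128),
+ #     unembedder = nn.Linear(128, 256, bias = False),
+ #     transformer = dict(dim = 128, depth = 1, window_size = 256)
+ # )
+ #
+ # seq = torch.randint(0, 256, (3, 256))        # one window at a time
+ # logits, cache = model(seq)                   # logits: (3, 256, 256)
+ # logits, cache = model(seq, cache = cache)    # second segment sees the first via the cache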
@@ -0,0 +1,69 @@
+ [project]
+ name = "locoformer"
+ version = "0.0.6"
+ description = "LocoFormer"
+ authors = [
+     { name = "Phil Wang", email = "lucidrains@gmail.com" }
+ ]
+ readme = "README.md"
+ requires-python = ">= 3.9"
+ license = { file = "LICENSE" }
+ keywords = [
+     'artificial intelligence',
+     'deep learning',
+     'transformer',
+     'attention mechanism',
+     'robotics',
+     'cross-embodiment',
+ ]
+
+ classifiers = [
+     'Development Status :: 4 - Beta',
+     'Intended Audience :: Developers',
+     'Topic :: Scientific/Engineering :: Artificial Intelligence',
+     'License :: OSI Approved :: MIT License',
+     'Programming Language :: Python :: 3.9',
+ ]
+
+ dependencies = [
+     "assoc-scan",
+     "einx>=0.3.0",
+     "einops>=0.8.0",
+     "rotary-embedding-torch",
+     "torch>=2.4",
+     "x-mlps-pytorch",
+ ]
+
+ [project.urls]
+ Homepage = "https://pypi.org/project/locoformer/"
+ Repository = "https://github.com/lucidrains/locoformer"
+
+ [project.optional-dependencies]
+
+ examples = [
+     "accelerate",
+     "tqdm"
+ ]
+
+ test = [
+     "pytest"
+ ]
+
+ [tool.pytest.ini_options]
+ pythonpath = [
+     "."
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.rye]
+ managed = true
+ dev-dependencies = []
+
+ [tool.hatch.metadata]
+ allow-direct-references = true
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["locoformer"]
@@ -0,0 +1,38 @@
+ import pytest
+ param = pytest.mark.parametrize
+
+ import torch
+ from x_mlps_pytorch import MLP
+
+ from einops import rearrange
+
+ def test_locoformer():
+     from locoformer.locoformer import Locoformer
+     from torch import nn
+
+     # window must cover the full 512 token sequence, as attention asserts seq_len <= window_size
+
+     model = Locoformer(
+         embedder = nn.Embedding(256, 128),
+         unembedder = nn.Linear(128, 256, bias = False),
+         value_network = MLP(128, 32, 1),
+         transformer = dict(
+             dim = 128,
+             depth = 1,
+             window_size = 512
+         )
+     )
+
+     seq = torch.randint(0, 256, (3, 512))
+
+     (logits, values), cache = model(seq, return_values = True)
+     (logits, values), cache = model(seq, return_values = True, cache = cache)
+     (logits, values), cache = model(seq, return_values = True, cache = cache)
+
+     assert logits.shape == (3, 512, 256)
+
+     stateful_forward = model.get_stateful_forward(has_batch_dim = True, return_values = True, inference_mode = True)
+
+     for state in seq.unbind(dim = -1):
+         state = rearrange(state, 'b -> b 1')
+
+         logits, values = stateful_forward(state)
+         assert logits.shape == (3, 1, 256)
@@ -0,0 +1,191 @@
+ # /// script
+ # dependencies = [
+ #   'accelerate',
+ #   'locoformer',
+ #   'tqdm'
+ # ]
+ # ///
+
+ import tqdm
+ import gzip
+ from math import ceil
+ import numpy as np
+
+ import torch
+ from torch import nn
+ from torch import from_numpy
+ from torch.optim import Adam
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+
+ from einops import rearrange
+ from accelerate import Accelerator
+
+ from locoformer.locoformer import Locoformer
+
+ # constants
+
+ NUM_BATCHES = int(1e5)
+ BATCH_SIZE = 16
+ LEARNING_RATE = 2e-4
+ VALIDATE_EVERY = 100
+
+ GENERATE_EVERY = 250
+ PRIME_LENGTH = 32
+ GENERATE_LENGTH = 1024
+
+ SEQ_LEN = 256
+ NUM_SEGMENTS = 4
+ FIXED_WINDOW_SIZE = False
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # sampling
+
+ def log(t, eps = 1e-20):
+     return t.clamp(min = eps).log()
+
+ def gumbel_noise(t):
+     return -log(-log(torch.rand_like(t)))
+
+ def gumbel_sample(logits, temperature = 1., eps = 1e-6, keepdim = True):
+     noise = gumbel_noise(logits)
+     return ((logits / max(temperature, eps)) + noise).argmax(dim = -1, keepdim = keepdim)
+
+ def topk_logits_filter(logits, frac_num_tokens = 0.1):
+     num_tokens = logits.shape[-1]
+     k = ceil(frac_num_tokens * num_tokens)
+
+     val, ind = torch.topk(logits, k)
+     probs = torch.full_like(logits, float('-inf'))
+     probs.scatter_(-1, ind, val)
+     return probs
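+
+ # usage sketch: filter a (num_tokens,) logits vector down to its top 10%, then draw one
+ # token with gumbel noise at temperature 1
+ #
+ # logits = torch.randn(256)
+ # sampled = gumbel_sample(topk_logits_filter(logits))   # (1,) token index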
+
+ # instantiate model
+
+ dim_model = 512
+
+ model = Locoformer(
+     embedder = nn.Embedding(256, dim_model),
+     unembedder = nn.Linear(dim_model, 256, bias = False),
+     transformer = dict(
+         dim = dim_model,
+         depth = 6,
+         fixed_window_size = FIXED_WINDOW_SIZE,
+         window_size = SEQ_LEN
+     )
+ )
+
+ # prepare enwik8 data
+
+ with gzip.open('./data/enwik8.gz') as file:
+     data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+     train_data, valid_data = np.split(data, [int(90e6)])
+     data_train, data_val = from_numpy(train_data), from_numpy(valid_data)
+
+ class TextSamplerDataset(Dataset):
+     def __init__(self, data, seq_len, segments):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+         self.segments = segments
+         self.total_len = seq_len * segments
+
+     def __getitem__(self, index):
+         rand_start = torch.randint(0, self.data.size(0) - self.total_len - 1, (1,))
+         full_seq = self.data[rand_start: rand_start + self.total_len + 1].long()
+         return full_seq
+
+     def __len__(self):
+         return self.data.size(0) // self.total_len
+
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN, NUM_SEGMENTS)
+ val_dataset = TextSamplerDataset(data_val, SEQ_LEN, NUM_SEGMENTS)
+ train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
+ val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
+
+ # optimizer
+
+ optim = Adam(model.parameters(), lr = LEARNING_RATE)
+
+ # prepare accelerate
+
+ accelerate = Accelerator()
+
+ model, optim, train_loader = accelerate.prepare(model, optim, train_loader)
+
+ # training loop
+
+ train_loader_iter = cycle(train_loader)
+ val_loader_iter = cycle(val_loader)
+
+ for i in range(NUM_BATCHES):
+     model.train()
+
+     seq = next(train_loader_iter)
+     seq, labels = seq[:, :-1], seq[:, 1:]
+
+     cache = None
+
+     for segment_seq, segment_labels in zip(seq.chunk(NUM_SEGMENTS, dim = -1), labels.chunk(NUM_SEGMENTS, dim = -1)):
+
+         logits, cache = model(
+             segment_seq,
+             cache = cache,
+             detach_cache = True
+         )
+
+         loss = F.cross_entropy(
+             rearrange(logits, 'b n l -> b l n'),
+             segment_labels
+         )
+
+         accelerate.backward(loss / NUM_SEGMENTS)
+         accelerate.print(f'[{i}] loss: {loss.item():.3f}')
+
+     optim.step()
+     optim.zero_grad()
+
+     if divisible_by(i + 1, GENERATE_EVERY):
+         model.eval()
+
+         val_seq = next(val_loader_iter)
+         prime = val_seq[0, :PRIME_LENGTH]
+
+         prime = prime.to(model.device)
+         out = prime
+
+         stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+
+         # sample
+
+         while out.shape[-1] < GENERATE_LENGTH:
+             filtered_logits = topk_logits_filter(logits[-1])
+
+             sampled = gumbel_sample(filtered_logits)
+             out = torch.cat((out, sampled), dim = -1)
+
+             logits = stateful_forward(sampled)
+
+         # decoded
+
+         decoded_prime = decode_tokens(prime.cpu())
+         decoded_string = decode_tokens(out[PRIME_LENGTH:].cpu())
+
+         print(f'\n\n[prime]: {decoded_prime}\n\n')
+         print('*' * 100)
+         print(f'\n\n [generated]: {decoded_string}\n\n')