locoformer-0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/__init__.py ADDED
@@ -0,0 +1 @@
+ from locoformer.locoformer import Locoformer
locoformer/locoformer.py ADDED
@@ -0,0 +1,414 @@
+ from __future__ import annotations
+ from functools import partial
+
+ import torch
+ from torch import Tensor, cat, stack, is_tensor
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
+ from torch.utils._pytree import tree_map
+
+ from einops import rearrange, einsum
+ from einops.layers.torch import Rearrange
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from assoc_scan import AssocScan
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def tree_map_tensor(x, fn):
+     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
+
+ def detach_all(x):
+     return tree_map_tensor(x, lambda t: t.detach())
+
+ def combine_kv_cache(cache1, cache2):
+     combined_cache = []
+
+     for layer_cache1, layer_cache2 in zip(cache1, cache2):
+         next_cache = cat((layer_cache1, layer_cache2), dim = -2)
+         combined_cache.append(next_cache)
+
+     return combined_cache
+
+ # generalized advantage estimate
+ # gae_t = delta_t + gamma * lam * mask_t * gae_{t+1}, computed as a reverse associative scan
+
+ @torch.no_grad()
+ def calc_gae(
+     rewards,
+     values,
+     masks,
+     gamma = 0.99,
+     lam = 0.95,
+     use_accelerated = None
+ ):
+     assert values.shape[-1] == rewards.shape[-1]
+     use_accelerated = default(use_accelerated, rewards.is_cuda)
+
+     values = F.pad(values, (0, 1), value = 0.)
+     values, values_next = values[..., :-1], values[..., 1:]
+
+     delta = rewards + gamma * values_next * masks - values
+     gates = gamma * lam * masks
+
+     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+     gae = scan(gates, delta)
+
+     returns = gae + values
+
+     return returns
+
+ # transformer-xl mask w/ flex attn
+
+ flex_attention = None
+
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     if torch.cuda.is_available():
+         flex_attention = torch.compile(flex_attention)
+ except ImportError:
+     pass
+
+ def create_xl_mask(
+     seq_len,
+     kv_seq_len,
+     window_size,
+     episode_ids = None,   # (b n) - in the case that within the same batch there are multiple episodes
+     lookback_blocks = 1,  # in transformer-xl, lookback is one window size block, but can be multiple for longer context
+     device = None
+ ):
+     assert kv_seq_len >= seq_len
+     assert window_size <= seq_len
+
+     offset = kv_seq_len - seq_len
+
+     def create_block_mask_fn(b, __, q, k):
+         offset_q = q + offset
+         block_q = offset_q // window_size
+         block_k = k // window_size
+
+         causal_mask = offset_q >= k
+
+         # in transformer-xl, the previous segment is fully attended to - may just double the segments and make this sliding for ease of inference logic
+
+         block_mask = (block_q >= block_k) & (block_q <= (block_k + lookback_blocks))
+
+         mask = causal_mask & block_mask
+
+         # handle intra-episodic attention if needed
+
+         if exists(episode_ids):
+             q_episode = episode_ids[b, q + offset]
+             k_episode = episode_ids[b, k]
+
+             intra_episode_mask = q_episode == k_episode
+             mask = mask & intra_episode_mask
+
+         return mask
+
+     create_kwargs = dict(device = device) if exists(device) else dict()
+     return create_block_mask(create_block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
+
+ def create_sliding_mask(
+     seq_len,
+     kv_seq_len,
+     window_size,
+     device = None
+ ):
+     assert kv_seq_len >= seq_len
+     offset = kv_seq_len - seq_len
+
+     def sliding_mask(_, __, q, k):
+         offset_q = q + offset
+         distance = offset_q - k
+
+         backward_sliding_mask = distance <= window_size
+         forward_sliding_mask = distance >= 0
+
+         return backward_sliding_mask & forward_sliding_mask
+
+     create_kwargs = dict(device = device) if exists(device) else dict()
+     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
+
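+ # hypothetical usage sketch for the masks above (they are not yet called in this file) -
+ # assuming flex attention is available and `q, k, v` are (batch, heads, seq, dim_head) tensors:
+ #
+ #   xl_mask = create_xl_mask(seq_len = 256, kv_seq_len = 512, window_size = 256)
+ #   out = flex_attention(q, k, v, block_mask = xl_mask)
+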
+ # transformer-xl with ppo
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         dim_head = 64,
+         heads = 8,
+         pre_rmsnorm = True
+     ):
+         super().__init__()
+         self.scale = dim_head ** -0.5
+
+         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
+
+         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         self.rotary_embed = RotaryEmbedding(dim_head)
+
+         dim_inner = dim_head * heads
+         self.to_q = LinearNoBias(dim, dim_inner)
+         self.to_kv = LinearNoBias(dim, dim_inner * 2)
+         self.to_out = LinearNoBias(dim_inner, dim)
+
+     def forward(
+         self,
+         tokens,
+         kv_cache = None,
+         return_kv_cache = False
+     ):
+         tokens = self.norm(tokens)
+
+         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
+
+         q, k, v = map(self.split_heads, (q, k, v))
+
+         q = q * self.scale
+
+         if exists(kv_cache):
+             ck, cv = kv_cache
+             k = cat((ck, k), dim = -2)
+             v = cat((cv, v), dim = -2)
+
+         if return_kv_cache:
+             next_kv_cache = stack((k, v))
+
+         q, k = self.rotary_embed.rotate_queries_with_cached_keys(q, k)
+
+         sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+
+         i, j = sim.shape[-2:]
+
+         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+
+         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+         attn = sim.softmax(dim = -1)
+
+         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+         out = self.merge_heads(out)
+
+         out = self.to_out(out)
+
+         if not return_kv_cache:
+             return out
+
+         return out, next_kv_cache
+
+ class FeedForward(Module):
+     def __init__(
+         self,
+         dim,
+         expansion_factor = 4.,
+         pre_rmsnorm = True
+     ):
+         super().__init__()
+         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
+
+         dim_inner = int(dim * expansion_factor * 2 / 3)
+
+         self.proj_in = Linear(dim, dim_inner * 2)
+         self.proj_out = Linear(dim_inner, dim)
+
+     def forward(
+         self,
+         x
+     ):
+         x = self.norm(x)
+
+         x, gates = self.proj_in(x).chunk(2, dim = -1)
+
+         x = x * F.gelu(gates)
+
+         return self.proj_out(x)
+
+ class TransformerXL(Module):
+     def __init__(
+         self,
+         dim,
+         depth,
+         dim_head = 64,
+         heads = 8,
+         expansion_factor = 4.,
+         final_norm = True
+     ):
+         super().__init__()
+
+         layers = ModuleList([])
+
+         for _ in range(depth):
+             attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+
+             ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+
+             layers.append(ModuleList([
+                 attn, ff
+             ]))
+
+         self.layers = layers
+         self.norm = RMSNorm(dim) if final_norm else Identity()
+
+     def forward(
+         self,
+         x,
+         cache = None,
+         return_kv_cache = False
+     ):
+
+         cache = default(cache, (None,) * len(self.layers))
+
+         next_kv_caches = []
+
+         for (attn, ff), kv_cache in zip(self.layers, cache):
+
+             attn_out, next_kv_cache = attn(x, kv_cache = kv_cache, return_kv_cache = True)
+
+             next_kv_caches.append(next_kv_cache)
+
+             x = attn_out + x
+             x = ff(x) + x
+
+         embed = self.norm(x)
+
+         if not return_kv_cache:
+             return embed
+
+         return embed, stack(next_kv_caches)
+
+ # class
+
+ class Locoformer(Module):
+     def __init__(
+         self,
+         embedder: Module,
+         unembedder: Module,
+         transformer: dict | TransformerXL,
+         value_network: Module | None = None
+     ):
+         super().__init__()
+
+         if isinstance(transformer, dict):
+             transformer = TransformerXL(**transformer)
+
+         self.transformer = transformer
+
+         self.embedder = embedder
+         self.unembedder = unembedder
+
+         self.value_network = value_network
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def get_stateful_forward(
+         self,
+         segment_size,
+         initial_states: Tensor | None = None,
+         inference_mode = False,
+         has_batch_dim = False,
+         **kwargs
+     ):
+         cache = None
+
+         def stateful_forward(state: Tensor, override_kwargs: dict = dict()):
+             nonlocal cache
+
+             # handle no batch, for easier time rolling out against envs
+
+             if not has_batch_dim:
+                 state = rearrange(state, '... -> 1 ...')
+
+             # forwards
+
+             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+
+             # handle cache
+
+             cache_len = cache.shape[-2]
+
+             if divisible_by(cache_len, segment_size * 2):
+                 cache = cache[..., -segment_size:, :]
+
+             # maybe remove batch
+
+             if not has_batch_dim:
+                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
+
+             return out
+
+         if inference_mode:
+             stateful_forward = torch.inference_mode()(stateful_forward)
+
+         # handle prompt
+
+         if not exists(initial_states):
+             return stateful_forward
+
+         initial_logits = []
+
+         for state_segments in initial_states.split(segment_size, dim = -1):
+
+             logits = stateful_forward(state_segments, dict(return_values = False))
+             initial_logits.append(logits)
+
+         initial_logits = cat(initial_logits, dim = -2)
+
+         return stateful_forward, initial_logits
+
+     def forward(
+         self,
+         state: Tensor,
+         cache: Tensor | None = None,
+         detach_cache = False,
+         return_values = False
+     ):
+
+         tokens = self.embedder(state)
+
+         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+
+         # unembed to actions - in language models this would be the next state
+
+         action_logits = self.unembedder(embed)
+
+         out = action_logits
+
+         # maybe detach cache
+
+         if detach_cache:
+             kv_cache = detach_all(kv_cache)
+
+         # handle returning of values
+
+         if return_values:
+             assert exists(self.value_network)
+
+             values = self.value_network(embed)
+
+             if values.ndim == 3:
+                 assert values.shape[-1] == 1
+                 values = rearrange(values, '... 1 -> ...')
+
+             out = (out, values)
+
+         # output and cache
+
+         return out, kv_cache
locoformer-0.0.5.dist-info/METADATA ADDED
@@ -0,0 +1,71 @@
+ Metadata-Version: 2.4
+ Name: locoformer
+ Version: 0.0.5
+ Summary: LocoFormer
+ Project-URL: Homepage, https://pypi.org/project/locoformer/
+ Project-URL: Repository, https://github.com/lucidrains/locoformer
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,attention mechanism,cross-embodiment,deep learning,robotics,transformer
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: assoc-scan
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: rotary-embedding-torch
+ Requires-Dist: torch>=2.4
+ Requires-Dist: x-mlps-pytorch
+ Provides-Extra: examples
+ Requires-Dist: accelerate; extra == 'examples'
+ Requires-Dist: tqdm; extra == 'examples'
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./fig3.png" width="400px"></img>
+
+ ## LocoFormer (wip)
+
+ [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
+
+ The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferred to the real world, the robot gained the ability to adapt to insults. The XL memories span multiple trials, which allowed the robot to learn in-context adaptation.
+
+ ## Sponsors
59
+
60
+ This open sourced work is sponsored by [Safe Sentinel](https://www.safesentinels.com/)
61
+
62
+ ## Citations
63
+
64
+ ```bibtex
65
+ @article{liu2025locoformer,
66
+ title = {LocoFormer: Generalist Locomotion via Long-Context Adaptation},
67
+ author = {Liu, Min and Pathak, Deepak and Agarwal, Ananye},
68
+ journal = {Conference on Robot Learning ({CoRL})},
69
+ year = {2025}
70
+ }
71
+ ```
locoformer-0.0.5.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=Yoh3hrj2E_91YLoYRa73wGzjdIiMdcd5ofNjkiVlogI,10570
+ locoformer-0.0.5.dist-info/METADATA,sha256=oe6HfOwWKQvusiJl1ukmNFcrGRhdDZ6NcKZi3upv-SY,3159
+ locoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ locoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.5.dist-info/RECORD,,
locoformer-0.0.5.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
locoformer-0.0.5.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.