locoformer 0.0.5-py3-none-any.whl → 0.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +102 -20
- {locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/METADATA +2 -2
- locoformer-0.0.7.dist-info/RECORD +6 -0
- locoformer-0.0.5.dist-info/RECORD +0 -6
- {locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/WHEEL +0 -0
- {locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -2,11 +2,12 @@ from __future__ import annotations
 from functools import partial
 
 import torch
-from torch import cat, stack, is_tensor
+from torch import nn, cat, stack, arange, Tensor, is_tensor
 import torch.nn.functional as F
-from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
+from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
 
+import einx
 from einops import rearrange, einsum
 from einops.layers.torch import Rearrange
 
@@ -24,6 +25,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def first(arr):
+    return arr[0]
+
 def divisible_by(num, den):
     return (num % den) == 0
 
@@ -148,9 +152,12 @@ class Attention(Module):
     def __init__(
         self,
         dim,
+        window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True
+        pre_rmsnorm = True,
+        fixed_window_size = False,
+        accept_value_residual = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -167,20 +174,55 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_v_gates = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        )
+
+        # value residual
+
+        self.accept_value_residual = accept_value_residual
+
+        if accept_value_residual:
+            self.to_value_residual_mix = Sequential(
+                LinearNoBias(dim, heads),
+                Rearrange('b n h -> b h n 1'),
+                nn.Sigmoid()
+            )
+
+        # fixed window size
+
+        self.fixed_window_size = fixed_window_size
+        self.window_size = window_size
+
     def forward(
         self,
         tokens,
+        value_residual = None,
         kv_cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
     ):
+        seq_len = tokens.shape[-2]
+        assert seq_len <= self.window_size
+
+        device = tokens.device
+
         tokens = self.norm(tokens)
 
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        orig_v = v
+
         q = q * self.scale
 
+        if exists(value_residual):
+            assert self.accept_value_residual
+            mix = self.to_value_residual_mix(tokens)
+            v = v.lerp(value_residual, mix)
+
         if exists(kv_cache):
             ck, cv = kv_cache
             k = cat((ck, k), dim = -2)
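The two new Sequential heads above produce a per-head, per-position sigmoid from the pre-attention tokens: to_v_gates scales the attention output, and to_value_residual_mix (only built when accept_value_residual = True) controls how much of the first layer's values are mixed back in via v.lerp. A minimal sketch of that computation, with illustrative shapes and a plain nn.Linear standing in for the package's LinearNoBias:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq_len, dim, heads, dim_head = 2, 16, 512, 8, 64

# same structure as the new to_v_gates / to_value_residual_mix heads
gate_head = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

tokens = torch.randn(batch, seq_len, dim)
v = torch.randn(batch, heads, seq_len, dim_head)               # this layer's values
value_residual = torch.randn(batch, heads, seq_len, dim_head)  # values kept from the first layer

mix = gate_head(tokens)                # (b, h, n, 1), each entry in (0, 1)
v_mixed = v.lerp(value_residual, mix)  # 0 keeps this layer's values, 1 uses the first layer's

print(v_mixed.shape)  # torch.Size([2, 8, 16, 64])
```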
@@ -195,7 +237,13 @@ class Attention(Module):
 
         i, j = sim.shape[-2:]
 
-        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+        if self.fixed_window_size:
+            i_seq = arange(i, device = device)
+            j_seq = arange(j, device = device) - (j - i)
+            dist = einx.subtract('i, j -> i j', i_seq, j_seq)
+            causal_mask = (dist < 0) | (dist > self.window_size)
+        else:
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
 
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
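The fixed_window_size branch above builds the mask from query/key distances rather than a plain upper-triangular mask, so keys further than window_size behind a query (or ahead of it) are dropped even when cached keys widen the key axis. A small illustration with made-up sizes; explicit broadcasting replaces einx.subtract so the snippet has no extra dependency:

```python
import torch

i, j, window_size = 4, 8, 4  # 4 new queries over 4 cached + 4 new keys

i_seq = torch.arange(i)
j_seq = torch.arange(j) - (j - i)       # key positions, shifted so the last i keys line up with the queries

dist = i_seq[:, None] - j_seq[None, :]  # equivalent to einx.subtract('i, j -> i j', i_seq, j_seq)
fixed_window_mask = (dist < 0) | (dist > window_size)  # drop future keys and keys beyond the window

# the non-fixed branch: plain causal mask over the widened key axis
plain_causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(fixed_window_mask.int())
print(plain_causal_mask.int())
```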
@@ -203,6 +251,8 @@ class Attention(Module):
 
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
 
+        out = out * self.to_v_gates(tokens)
+
         out = self.merge_heads(out)
 
         out = self.to_out(out)
@@ -210,7 +260,7 @@ class Attention(Module):
         if not return_kv_cache:
             return out
 
-        return out, next_kv_cache
+        return out, (next_kv_cache, orig_v)
 
 class FeedForward(Module):
     def __init__(
@@ -244,17 +294,21 @@ class TransformerXL(Module):
         self,
         dim,
         depth,
+        window_size,
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
-        final_norm = True
+        final_norm = True,
+        fixed_window_size = False,
     ):
         super().__init__()
 
         layers = ModuleList([])
 
-        for _ in range(depth):
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+        for i in range(depth):
+            is_first = i == 0
+
+            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
 
             ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
 
@@ -265,6 +319,11 @@ class TransformerXL(Module):
         self.layers = layers
         self.norm = RMSNorm(dim) if final_norm else Identity()
 
+        # fixed window size
+
+        self.fixed_window_size = fixed_window_size
+        self.window_size = window_size
+
     def forward(
         self,
         x,
@@ -275,22 +334,28 @@ class TransformerXL(Module):
         cache = default(cache, (None,) * len(self.layers))
 
         next_kv_caches = []
+        value_residual = None
 
         for (attn, ff), kv_cache in zip(self.layers, cache):
 
-            attn_out, next_kv_cache = attn(x, kv_cache = kv_cache, return_kv_cache = True)
-
-            next_kv_caches.append(next_kv_cache)
+            attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
 
             x = attn_out + x
             x = ff(x) + x
 
+            next_kv_caches.append(next_kv_cache)
+            value_residual = default(value_residual, values)
+
         embed = self.norm(x)
 
         if not return_kv_cache:
             return embed
 
-        return embed, stack(next_kv_caches)
+        next_kv_cache = stack(next_kv_caches)
+
+        next_kv_cache = next_kv_cache[..., -self.window_size:, :]
+
+        return embed, next_kv_cache
 
 # class
 
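In the rewritten loop, only the first layer runs without a value residual; its raw values (orig_v) come back alongside the KV cache and, through default(value_residual, values), every later layer receives them as value_residual. A hedged toy sketch of just that bookkeeping, with a stand-in for Attention:

```python
import torch

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def toy_attention(x, value_residual = None):
    # stand-in for Attention.forward: returns an output plus the values it computed (orig_v)
    values = x * 0.5
    if exists(value_residual):
        values = values.lerp(value_residual, 0.1)  # later layers pull toward the first layer's values
    return x + values, values

x = torch.randn(2, 16, 512)
value_residual = None

for _ in range(4):
    x, values = toy_attention(x, value_residual = value_residual)
    # mirrors `value_residual = default(value_residual, values)`:
    # set once from the first layer, then left untouched
    value_residual = default(value_residual, values)

print(x.shape)  # torch.Size([2, 16, 512])
```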
@@ -314,28 +379,42 @@ class Locoformer(Module):
 
         self.value_network = value_network
 
+        self.fixed_window_size = transformer.fixed_window_size
+        self.window_size = transformer.window_size
+
     @property
     def device(self):
         return next(self.parameters()).device
 
+    def actor_parameters(self):
+        return self.unembedder.parameters()
+
+    def critic_parameters(self):
+        return self.value_network.parameters()
+
     def get_stateful_forward(
         self,
-        segment_size,
         initial_states: Tensor | None = None,
         inference_mode = False,
         has_batch_dim = False,
+        has_time_dim = False,
         **kwargs
     ):
+        window_size = self.window_size
+
         cache = None
 
-        def stateful_forward(state: Tensor, override_kwargs
+        def stateful_forward(state: Tensor, **override_kwargs):
             nonlocal cache
 
-            # handle no batch, for easier time rolling out against envs
+            # handle no batch or time, for easier time rolling out against envs
 
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')
 
+            if not has_time_dim:
+                state = rearrange(state, '... d -> ... 1 d')
+
             # forwards
 
             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
@@ -344,10 +423,13 @@ class Locoformer(Module):
 
             cache_len = cache.shape[-2]
 
-            if divisible_by(cache_len, segment_size * 2):
-                cache = cache[..., -segment_size:, :]
+            if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
+                cache = cache[..., -window_size:, :]
+
+            # maybe remove batch or time
 
-            # maybe remove batch
+            if not has_time_dim:
+                out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
 
             if not has_batch_dim:
                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
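The new truncation rule means the rolling KV cache is clipped back to window_size on every call when fixed_window_size is set, and otherwise only once it has grown to window_size * 2 (Transformer-XL style). A toy trace of the cache length, under the assumption of one timestep per stateful_forward call:

```python
window_size, fixed_window_size = 4, False

def divisible_by(num, den):
    return (num % den) == 0

cache_len = 0
for step in range(1, 13):
    cache_len += 1  # one new timestep appended to the KV cache
    if fixed_window_size or divisible_by(cache_len, window_size * 2):
        cache_len = window_size  # mirrors cache = cache[..., -window_size:, :]
    print(step, cache_len)

# prints 1,2,3,4,5,6,7,4,5,6,7,4 — the cache is clipped back to window_size
# whenever it reaches window_size * 2, so attention never sees more than two windows
```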
@@ -364,7 +446,7 @@ class Locoformer(Module):
 
         initial_logits = []
 
-        for state_segments in initial_states.split(segment_size, dim = -1):
+        for state_segments in initial_states.split(self.window_size, dim = -1):
 
             logits = stateful_forward(state_segments, return_values = False)
             initial_logits.append(logits)
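For context, a hedged usage sketch of the updated TransformerXL constructor and forward, assuming (as the diff body suggests) that forward accepts cache and return_kv_cache keyword arguments; all hyperparameter values are arbitrary:

```python
import torch
from locoformer.locoformer import TransformerXL

model = TransformerXL(
    dim = 128,
    depth = 2,
    window_size = 16,           # per-segment attention window, now a required argument
    fixed_window_size = False   # True masks by distance instead of pure causality over the XL cache
)

x = torch.randn(1, 16, 128)     # segment length must not exceed window_size

embed, cache = model(x, return_kv_cache = True)
# cache is the stacked per-layer KV, trimmed to the last window_size steps

embed_next, cache_next = model(torch.randn(1, 16, 128), cache = cache, return_kv_cache = True)
print(embed.shape, embed_next.shape)
```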
{locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.5
+Version: 0.0.7
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
locoformer-0.0.7.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+locoformer/locoformer.py,sha256=lJQs0CKr9iztF8tie1FRUVEItCt-IZbIILQqKcgK2sI,13142
+locoformer-0.0.7.dist-info/METADATA,sha256=PZ_phKV3t4Bha0GnUB5HPmE9w8A5fvNevsuN532Ls3s,3193
+locoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+locoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+locoformer-0.0.7.dist-info/RECORD,,
locoformer-0.0.5.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
-locoformer/locoformer.py,sha256=Yoh3hrj2E_91YLoYRa73wGzjdIiMdcd5ofNjkiVlogI,10570
-locoformer-0.0.5.dist-info/METADATA,sha256=oe6HfOwWKQvusiJl1ukmNFcrGRhdDZ6NcKZi3upv-SY,3159
-locoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-locoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-locoformer-0.0.5.dist-info/RECORD,,
{locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/WHEEL
File without changes
{locoformer-0.0.5.dist-info → locoformer-0.0.7.dist-info}/licenses/LICENSE
File without changes