locoformer 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/locoformer.py CHANGED
@@ -2,11 +2,12 @@ from __future__ import annotations
  from functools import partial

  import torch
- from torch import cat, stack, is_tensor
+ from torch import nn, cat, stack, arange, Tensor, is_tensor
  import torch.nn.functional as F
- from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
+ from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
  from torch.utils._pytree import tree_map

+ import einx
  from einops import rearrange, einsum
  from einops.layers.torch import Rearrange

@@ -24,6 +25,9 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def first(arr):
+     return arr[0]
+
  def divisible_by(num, den):
      return (num % den) == 0

@@ -148,9 +152,12 @@ class Attention(Module):
      def __init__(
          self,
          dim,
+         window_size,
          dim_head = 64,
          heads = 8,
-         pre_rmsnorm = True
+         pre_rmsnorm = True,
+         fixed_window_size = False,
+         accept_value_residual = False
      ):
          super().__init__()
          self.scale = dim_head ** -0.5
@@ -167,20 +174,55 @@ class Attention(Module):
          self.to_kv = LinearNoBias(dim, dim_inner * 2)
          self.to_out = LinearNoBias(dim_inner, dim)

+         self.to_v_gates = Sequential(
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> b h n 1'),
+             nn.Sigmoid()
+         )
+
+         # value residual
+
+         self.accept_value_residual = accept_value_residual
+
+         if accept_value_residual:
+             self.to_value_residual_mix = Sequential(
+                 LinearNoBias(dim, heads),
+                 Rearrange('b n h -> b h n 1'),
+                 nn.Sigmoid()
+             )
+
+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
      def forward(
          self,
          tokens,
+         value_residual = None,
          kv_cache = None,
-         return_kv_cache = False
+         return_kv_cache = False,
      ):
+         seq_len = tokens.shape[-2]
+         assert seq_len <= self.window_size
+
+         device = tokens.device
+
          tokens = self.norm(tokens)

          q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

          q, k, v = map(self.split_heads, (q, k, v))

+         orig_v = v
+
          q = q * self.scale

+         if exists(value_residual):
+             assert self.accept_value_residual
+             mix = self.to_value_residual_mix(tokens)
+             v = v.lerp(value_residual, mix)
+
          if exists(kv_cache):
              ck, cv = kv_cache
              k = cat((ck, k), dim = -2)
@@ -195,7 +237,13 @@ class Attention(Module):

          i, j = sim.shape[-2:]

-         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+         if self.fixed_window_size:
+             i_seq = arange(i, device = device)
+             j_seq = arange(j, device = device) - (j - i)
+             dist = einx.subtract('i, j -> i j', i_seq, j_seq)
+             causal_mask = (dist < 0) | (dist > self.window_size)
+         else:
+             causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)

          sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
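Aside (not part of the package diff): the `einx.subtract` expression above computes, for each query/key pair, how many steps the key lies behind the query; keys in the future or further back than `window_size` are masked. A minimal plain-`torch` sketch of an equivalent mask, with the helper name `sliding_window_causal_mask` being my own for illustration:

```python
import torch

def sliding_window_causal_mask(i, j, window_size, device = None):
    # queries occupy the last i positions of a length-j key sequence (kv cache prepended)
    q_pos = torch.arange(i, device = device) + (j - i)
    k_pos = torch.arange(j, device = device)

    dist = q_pos[:, None] - k_pos[None, :]  # how far each key lies behind each query

    # True entries are masked out: future keys, or keys beyond the fixed window
    return (dist < 0) | (dist > window_size)

mask = sliding_window_causal_mask(4, 8, window_size = 3)
```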
 
@@ -203,6 +251,8 @@ class Attention(Module):

          out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

+         out = out * self.to_v_gates(tokens)
+
          out = self.merge_heads(out)

          out = self.to_out(out)
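Aside (not part of the package diff): the new `to_v_gates` branch gates the attention output per head and position with a sigmoid before heads are merged. A small shape sketch under assumed dimensions:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 512, 8, 64

to_v_gates = nn.Sequential(
    nn.Linear(dim, heads, bias = False),   # one gate logit per head
    Rearrange('b n h -> b h n 1'),         # -> (batch, heads, seq, 1)
    nn.Sigmoid()
)

tokens = torch.randn(2, 16, dim)           # (batch, seq, dim)
out = torch.randn(2, heads, 16, dim_head)  # per-head attention output

gated = out * to_v_gates(tokens)           # gate broadcasts over the feature dimension
```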
@@ -210,7 +260,7 @@ class Attention(Module):
          if not return_kv_cache:
              return out

-         return out, next_kv_cache
+         return out, (next_kv_cache, orig_v)

  class FeedForward(Module):
      def __init__(
@@ -244,17 +294,21 @@ class TransformerXL(Module):
          self,
          dim,
          depth,
+         window_size,
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
-         final_norm = True
+         final_norm = True,
+         fixed_window_size = False,
      ):
          super().__init__()

          layers = ModuleList([])

-         for _ in range(depth):
-             attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+         for i in range(depth):
+             is_first = i == 0
+
+             attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)

              ff = FeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -265,6 +319,11 @@ class TransformerXL(Module):
          self.layers = layers
          self.norm = RMSNorm(dim) if final_norm else Identity()

+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
      def forward(
          self,
          x,
@@ -275,22 +334,28 @@ class TransformerXL(Module):
          cache = default(cache, (None,) * len(self.layers))

          next_kv_caches = []
+         value_residual = None

          for (attn, ff), kv_cache in zip(self.layers, cache):

-             attn_out, next_kv_cache = attn(x, kv_cache = kv_cache, return_kv_cache = True)
-
-             next_kv_caches.append(next_kv_cache)
+             attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)

              x = attn_out + x
              x = ff(x) + x

+             next_kv_caches.append(next_kv_cache)
+             value_residual = default(value_residual, values)
+
          embed = self.norm(x)

          if not return_kv_cache:
              return embed

-         return embed, stack(next_kv_caches)
+         next_kv_cache = stack(next_kv_caches)
+
+         next_kv_cache = next_kv_cache[..., -self.window_size:, :]
+
+         return embed, next_kv_cache

  # class
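Aside (not part of the package diff): the loop above implements value residual learning: the first layer's unmixed values (`orig_v`) are carried forward, and every later layer interpolates its own values toward them with a learned per-head sigmoid gate. A standalone sketch of just the mixing step, with shapes chosen arbitrarily:

```python
import torch

b, h, n, d = 2, 8, 16, 64

v_first = torch.randn(b, h, n, d)  # values returned by the first attention layer
v_layer = torch.randn(b, h, n, d)  # values computed by a later layer
mix     = torch.rand(b, h, n, 1)   # sigmoid gate in (0, 1), per head and position

# Tensor.lerp(end, weight) = self + weight * (end - self):
# mix = 0 keeps the layer's own values, mix = 1 uses the first layer's values
v_mixed = v_layer.lerp(v_first, mix)
```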
 
@@ -314,28 +379,42 @@ class Locoformer(Module):

          self.value_network = value_network

+         self.fixed_window_size = transformer.fixed_window_size
+         self.window_size = transformer.window_size
+
      @property
      def device(self):
          return next(self.parameters()).device

+     def actor_parameters(self):
+         return self.unembedder.parameters()
+
+     def critic_parameters(self):
+         return self.value_network.parameters()
+
      def get_stateful_forward(
          self,
-         segment_size,
          initial_states: Tensor | None = None,
          inference_mode = False,
          has_batch_dim = False,
+         has_time_dim = False,
          **kwargs
      ):
+         window_size = self.window_size
+
          cache = None

-         def stateful_forward(state: Tensor, override_kwargs: dict = dict()):
+         def stateful_forward(state: Tensor, **override_kwargs):
              nonlocal cache

-             # handle no batch, for easier time rolling out against envs
+             # handle no batch or time, for easier time rolling out against envs

              if not has_batch_dim:
                  state = rearrange(state, '... -> 1 ...')

+             if not has_time_dim:
+                 state = rearrange(state, '... d -> ... 1 d')
+
              # forwards

              out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
@@ -344,10 +423,13 @@ class Locoformer(Module):

              cache_len = cache.shape[-2]

-             if divisible_by(cache_len, segment_size * 2):
-                 cache = cache[..., -segment_size:, :]
+             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
+                 cache = cache[..., -window_size:, :]
+
+             # maybe remove batch or time

-             # maybe remove batch
+             if not has_time_dim:
+                 out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))

              if not has_batch_dim:
                  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
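Aside (not part of the package diff): a toy trace of the trimming condition above, assuming `window_size = 4`. With `fixed_window_size = False` the cache grows to twice the window (Transformer-XL style) before being cut back; with `fixed_window_size = True` it is trimmed on every step:

```python
def should_trim(cache_len, window_size, fixed_window_size):
    # mirrors the condition in stateful_forward above
    return fixed_window_size or (cache_len % (window_size * 2)) == 0

cache_len = 0
for step in range(10):
    cache_len += 1                    # one new timestep cached per call
    if should_trim(cache_len, 4, False):
        cache_len = 4                 # keep only the last window_size entries
    print(step, cache_len)            # grows 1..8, trims to 4, grows to 8, trims again
```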
@@ -364,7 +446,7 @@ class Locoformer(Module):

          initial_logits = []

-         for state_segments in initial_states.split(segment_size, dim = -1):
+         for state_segments in initial_states.split(self.window_size, dim = -1):

              logits = stateful_forward(state_segments, return_values = False)
              initial_logits.append(logits)
locoformer-{0.0.5 → 0.0.7}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.5
+ Version: 0.0.7
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

locoformer-0.0.7.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=lJQs0CKr9iztF8tie1FRUVEItCt-IZbIILQqKcgK2sI,13142
+ locoformer-0.0.7.dist-info/METADATA,sha256=PZ_phKV3t4Bha0GnUB5HPmE9w8A5fvNevsuN532Ls3s,3193
+ locoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ locoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.7.dist-info/RECORD,,
locoformer-0.0.5.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
- locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
- locoformer/locoformer.py,sha256=Yoh3hrj2E_91YLoYRa73wGzjdIiMdcd5ofNjkiVlogI,10570
- locoformer-0.0.5.dist-info/METADATA,sha256=oe6HfOwWKQvusiJl1ukmNFcrGRhdDZ6NcKZi3upv-SY,3159
- locoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- locoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- locoformer-0.0.5.dist-info/RECORD,,