locoformer 0.0.17__tar.gz → 0.0.30__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {locoformer-0.0.17 → locoformer-0.0.30}/PKG-INFO +3 -2
- {locoformer-0.0.17 → locoformer-0.0.30}/README.md +1 -1
- {locoformer-0.0.17 → locoformer-0.0.30}/locoformer/locoformer.py +302 -58
- {locoformer-0.0.17 → locoformer-0.0.30}/pyproject.toml +2 -1
- {locoformer-0.0.17 → locoformer-0.0.30}/tests/test_locoformer.py +66 -5
- {locoformer-0.0.17 → locoformer-0.0.30}/train.py +1 -1
- {locoformer-0.0.17 → locoformer-0.0.30}/train_gym.py +25 -26
- {locoformer-0.0.17 → locoformer-0.0.30}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/.gitignore +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/LICENSE +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/data/README.md +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/data/enwik8.gz +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/fig3.png +0 -0
- {locoformer-0.0.17 → locoformer-0.0.30}/locoformer/__init__.py +0 -0
```diff
--- locoformer-0.0.17/PKG-INFO
+++ locoformer-0.0.30/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.17
+Version: 0.0.30
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,6 +38,7 @@ Requires-Dist: assoc-scan
 Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
@@ -54,7 +55,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
```
```diff
--- locoformer-0.0.17/README.md
+++ locoformer-0.0.30/README.md
@@ -4,7 +4,7 @@
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
```
```diff
--- locoformer-0.0.17/locoformer/locoformer.py
+++ locoformer-0.0.30/locoformer/locoformer.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 from functools import partial
 
 from pathlib import Path
@@ -16,7 +17,7 @@ import torch
 from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
 from torch.optim import Optimizer
 
@@ -26,12 +27,16 @@ from einops.layers.torch import Rearrange
 
 from rotary_embedding_torch import RotaryEmbedding
 
+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan
 
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
 
+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions
 
 def exists(v):
```
```diff
@@ -43,6 +48,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0
 
@@ -51,6 +59,9 @@ def divisible_by(num, den):
 def log(t, eps = 1e-20):
     return t.clamp_min(eps).log()
 
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
@@ -67,6 +78,9 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)
```
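The new `normalize` helper is what later standardizes GAE advantages in `ppo`. A tiny illustration of why the `clamp_min` on the standard deviation matters (the zero-variance batch here is a hypothetical worst case that would otherwise divide by zero):

```python
import torch

def normalize(t, eps = 1e-5):
    # same as the helper added above
    return (t - t.mean()) / t.std().clamp_min(eps)

advantage = torch.tensor([1., 1., 1.])   # degenerate batch, std == 0
print(normalize(advantage))              # finite zeros, thanks to the eps clamp
```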
```diff
@@ -250,6 +264,57 @@ class ReplayDataset(Dataset):
 
         return data
 
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        if self.shuffle_episodes and episode_indices.numel() > 1:
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
 class ReplayBuffer:
 
     @beartype
```
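This `RemappedReplayDataset` is the piece that turns single episodes into multi-trial sequences for in-context adaptation: each row of `episode_mapping` lists episode indices that get concatenated timewise, with `_lens` summed and `_episode_indices` tagging every timestep. A usage sketch, condensed from the test added to `tests/test_locoformer.py` further down (the `dim = -1` pairing is this sketch's choice, so each row is one (previous, next) trial pair; the `dataloader` signature that accepts `episode_mapping` is added a few hunks below):

```python
from torch import arange, stack

# assumes a ReplayBuffer that already holds a few stored episodes
episode_indices = arange(len(replay_buffer))   # ReplayBuffer gains __len__ in the next hunk

# pair every episode with its successor, two trials per training sequence
episode_mapping = stack((episode_indices[:-1], episode_indices[1:]), dim = -1)

dataloader = replay_buffer.dataloader(
    batch_size = 1,
    episode_mapping = episode_mapping
)

batch = next(iter(dataloader))
# batch['_lens'] holds the summed episode lengths
# batch['_episode_indices'] marks which episode each timestep came from
```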
```diff
@@ -314,6 +379,9 @@ class ReplayBuffer:
 
         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
 
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
     def reset_(self):
         self.episode_lens[:] = 0
         self.episode_index = 0
```
```diff
@@ -375,15 +443,91 @@ class ReplayBuffer:
 
         return self.memory_namedtuple(**data)
 
-    def dataset(self) -> Dataset:
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
         self.flush()
 
-        return ReplayDataset(self.folder)
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset
 
-    def dataloader(self, batch_size, **kwargs) -> DataLoader:
+        return RemappedReplayDataset(dataset, episode_mapping)
+
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
         self.flush()
 
-        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+# normalization + conditioning (needed for the commands to the robot)
+
+class MaybeAdaRMSNormWrapper(Module):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_cond = None
+    ):
+        super().__init__()
+        condition = exists(dim_cond)
+
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+        self.accept_condition = condition
+
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+            nn.init.zeros_(self.to_gamma.weight)
+            nn.init.zeros_(self.to_ada_norm_zero.weight)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+    def forward(
+        self,
+        x,
+        cond = None,
+        **kwargs
+    ):
+
+        need_cond = self.accept_condition
+        assert xnor(exists(cond), need_cond)
+
+        prenormed = self.norm(x)
+
+        if need_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')
+
+            scale_in = self.to_gamma(cond)
+            prenormed = prenormed * (scale_in + 1.)
+
+        all_fn_out = self.fn(prenormed, **kwargs)
+
+        if not need_cond:
+            return all_fn_out
+
+        # function may return multiple args
+
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+        if need_cond:
+            scale_out = self.to_ada_norm_zero(cond).sigmoid()
+            out = out * scale_out
+
+        # restore
+
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+        return all_fn_out
 
 # transformer-xl with ppo
 
```
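With this wrapper, the pre-RMSNorm moves out of `Attention` and `FeedForward` (removed in the hunks just below) and optionally becomes adaptive: the condition scales the normalized input by `1 + gamma(cond)` and gates the output through a sigmoid whose bias starts at `-5.`, so the conditioned branch opens up gradually, adaLN-zero style. A minimal sketch of both modes, assuming the class above is importable:

```python
import torch
from torch import nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

plain = MaybeAdaRMSNormWrapper(ff, dim = 64)                      # pre-norm only
conditioned = MaybeAdaRMSNormWrapper(ff, dim = 64, dim_cond = 16) # adaptive norm

x = torch.randn(2, 10, 64)
cond = torch.randn(2, 16)

out = plain(x)                      # passing cond here would trip the xnor assert
out = conditioned(x, cond = cond)   # a (b, d) cond is broadcast over the time axis
```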
```diff
@@ -394,15 +538,12 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True,
         fixed_window_size = False,
         accept_value_residual = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
 
-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -446,8 +587,6 @@ class Attention(Module):
 
         device = tokens.device
 
-        tokens = self.norm(tokens)
-
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
 
         q, k, v = map(self.split_heads, (q, k, v))
```
```diff
@@ -536,19 +675,24 @@ class TransformerXL(Module):
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
     ):
         super().__init__()
 
+        condition = exists(dim_cond)
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim_cond)
+
         layers = ModuleList([])
 
         for i in range(depth):
             is_first = i == 0
 
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))
 
-            ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
 
             layers.append(ModuleList([
                 attn, ff
```
```diff
@@ -603,14 +747,21 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
-        calc_gae_kwargs: dict = dict()
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
     ):
         super().__init__()
 
```
```diff
@@ -622,11 +773,30 @@ class Locoformer(Module):
         self.embedder = embedder
         self.unembedder = unembedder
 
-        self.value_network = value_network
-
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
         # ppo related
 
         self.discount_factor = discount_factor
```
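The critic head is now classification-as-regression: `to_value_pred` emits `num_reward_bins` logits, and the same `HLGaussLoss` both trains those logits against scalar returns and collapses them back to expected scalar values (the two call patterns used in the `ppo` and `forward` hunks below). A small sketch with illustrative shapes, reusing the reward range from the tests:

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -100., max_value = 100., num_bins = 32)

value_logits = torch.randn(2, 10, 32)              # output of to_value_pred
returns = torch.empty(2, 10).uniform_(-100., 100.)

loss = hl_gauss(value_logits, returns)   # with a target: loss on soft "two-hot" bins
values = hl_gauss(value_logits)          # without one: expected scalar value per timestep
```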
```diff
@@ -638,6 +808,19 @@ class Locoformer(Module):
 
         self.calc_gae_kwargs = calc_gae_kwargs
 
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
         # loss related
 
         self.register_buffer('zero', tensor(0.), persistent = False)
```
```diff
@@ -650,10 +833,10 @@ class Locoformer(Module):
         return self.unembedder.parameters()
 
     def critic_parameters(self):
-        if not exists(self.value_network):
+        if not exists(self.to_value_pred):
             return []
 
-        return self.value_network.parameters()
+        return self.to_value_pred.parameters()
 
     def ppo(
         self,
```
```diff
@@ -663,12 +846,20 @@ class Locoformer(Module):
         reward,
         old_value,
         mask,
+        episode_lens,
         actor_optim: Optimizer | None = None,
         critic_optim: Optimizer | None = None
     ):
         window_size = self.window_size
         total_learnable_tokens = mask.sum().item()
 
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
         windowed_tensors = [
             t.split(window_size, dim = 1) for t in
             (
@@ -677,7 +868,9 @@ class Locoformer(Module):
             old_action_log_prob,
             reward,
             old_value,
-            mask
+            mask,
+            advantage,
+            returns
             )
         ]
 
@@ -694,10 +887,12 @@ class Locoformer(Module):
             old_action_log_prob,
             reward,
             old_value,
-            mask
+            mask,
+            advantage,
+            returns
         ) in zip(*windowed_tensors):
 
-            (action_logits,
+            (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
 
             entropy = calc_entropy(action_logits)
 
             action = rearrange(action, 'b t -> b t 1')
```
```diff
@@ -709,9 +904,10 @@ class Locoformer(Module):
             eps_clip = self.ppo_eps_clip
             ratio = (log_prob - old_action_log_prob).exp()
 
-
-
-
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
 
             actor_loss = actor_loss - self.ppo_entropy_weight * entropy
 
```
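This hunk is where the optional simple policy optimization objective lands. Both surrogates penalize the probability ratio drifting from 1, but PPO clips it hard (zero gradient outside the trust region) while SPO subtracts a smooth quadratic penalty scaled by `|advantage|`. The two lines from the diff, runnable side by side on toy tensors:

```python
import torch

ratio = torch.tensor([0.5, 1.0, 1.5])       # new_prob / old_prob per action
advantage = torch.tensor([1.0, -2.0, 0.5])
eps_clip = 0.2

# ppo - the gradient vanishes once ratio leaves [1 - eps, 1 + eps]
ppo_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

# spo (https://arxiv.org/abs/2401.16025) - the quadratic penalty keeps gradients everywhere
spo_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
```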
```diff
@@ -720,11 +916,13 @@ class Locoformer(Module):
 
             # update critic
 
-            value_loss =
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
 
             value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
             clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
-            clipped_value_loss =
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
 
             critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
 
@@ -750,28 +948,48 @@ class Locoformer(Module):
 
         return mean_actor_loss.detach(), mean_critic_loss.detach()
 
+    def state_to_rewards(
+        self,
+        state
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+        return stack(rewards)
+
     def wrap_env_functions(self, env):
 
-        def
-
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el
 
-
-
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)
 
-            return
+            return tree_map(transform_output, env_reset_out)
 
         def wrapped_step(action, *args, **kwargs):
-            out = env.step(action.item(), *args, **kwargs)
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
 
-
-
-
-
-                return tensor(el)
-            else:
-                return el
+            env_step_out_torch = tree_map(transform_output, env_step_out)
+
+            if not self.has_reward_shaping:
+                return env_step_out_torch
+
+            shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards
 
         return wrapped_reset, wrapped_step
 
```
```diff
@@ -781,6 +999,7 @@ class Locoformer(Module):
         inference_mode = False,
         has_batch_dim = False,
         has_time_dim = False,
+        state_time_dim = 1,
         **kwargs
     ):
         window_size = self.window_size
@@ -796,23 +1015,16 @@ class Locoformer(Module):
             state = rearrange(state, '... -> 1 ...')
 
         if not has_time_dim:
-            state =
+            state = state.unsqueeze(state_time_dim)
 
         # forwards
 
         out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
 
-        # handle cache
-
-        cache_len = cache.shape[-2]
-
-        if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-            cache = cache[..., -window_size:, :]
-
         # maybe remove batch or time
 
         if not has_time_dim:
-            out = tree_map_tensor(out, lambda t:
+            out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))
 
         if not has_batch_dim:
             out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
```
```diff
@@ -841,16 +1053,35 @@ class Locoformer(Module):
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: Cache | None = None,
         detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
     ):
 
         state = state.to(self.device)
 
         tokens = self.embedder(state)
 
-        embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
+
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)
 
         # unembed to actions - in language models this would be the next state
 
```
```diff
@@ -866,16 +1097,29 @@ class Locoformer(Module):
         # handle returning of values
 
         if return_values:
-            assert exists(self.value_network)
+            assert exists(self.to_value_pred)
 
-            values = self.value_network(embed)
+            values = self.to_value_pred(embed)
 
-            if
-
-            values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values
 
             out = (out, values)
 
         # output and cache
 
-
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
```
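`forward` now returns the `Cache` pair `(next_timestep, kv_cache)` and takes over the cache curtailing that previously lived in `stateful_forward`, plus the optional layer-shift recurrence. A standalone sketch of just that bookkeeping (the layer-first cache layout is inferred from the `torch.roll(..., dims = 0)` call; all shapes are illustrative):

```python
import torch

layers, batch, heads, dim_head, window_size = 2, 1, 4, 16, 4

kv_cache = torch.randn(layers, batch, heads, 0, dim_head * 2)   # grows along dim -2

for next_timestep in range(1, 17):
    step_kv = torch.randn(layers, batch, heads, 1, dim_head * 2)
    kv_cache = torch.cat((kv_cache, step_kv), dim = -2)

    # curtail to the last window once two full windows have accumulated
    if next_timestep % (window_size * 2) == 0:
        kv_cache = kv_cache[..., -window_size:, :]

    # recurrent memory (Ding et al. 2020): at each window boundary, hand every
    # layer's cache down one layer, growing the receptive field with depth
    if next_timestep % window_size == 0:
        kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
```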
```diff
--- locoformer-0.0.17/pyproject.toml
+++ locoformer-0.0.30/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.17"
+version = "0.0.30"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
     "beartype",
     "einx>=0.3.0",
     "einops>=0.8.0",
+    "hl-gauss-pytorch>=0.2.0",
     "rotary-embedding-torch",
     "torch>=2.4",
     "x-mlps-pytorch",
```
```diff
--- locoformer-0.0.17/tests/test_locoformer.py
+++ locoformer-0.0.30/tests/test_locoformer.py
@@ -2,18 +2,25 @@ import pytest
 param = pytest.mark.parametrize
 
 import torch
+from torch import nn
 from x_mlps_pytorch import MLP
 
 from einops import rearrange
 
-
-
-
+from locoformer.locoformer import Locoformer
+
+@param('recurrent_kv_cache', (False, True))
+def test_locoformer(
+    recurrent_kv_cache
+):
 
     model = Locoformer(
         embedder = nn.Embedding(256, 128),
         unembedder = nn.Linear(128, 256, bias = False),
-        value_network = MLP(128,
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        recurrent_kv_cache = recurrent_kv_cache,
         transformer = dict(
             dim = 128,
             depth = 1,
```
```diff
@@ -83,4 +90,58 @@ def test_replay():
 
     dataloader = replay_buffer.dataloader(batch_size = 3)
 
-    assert next(iter(dataloader))['state'].shape[0] == 3
+    assert next(iter(dataloader))['state'].shape[0] == 3
+
+    # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+    # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+    from torch import stack, arange
+
+    episode_indices = arange(len(replay_buffer))
+    remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+    dataloader = replay_buffer.dataloader(
+        batch_size = 1,
+        episode_mapping = remapped_episodes
+    )
+
+    assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+def test_reward_shaping():
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        reward_shaping_fns = [
+            lambda state: (state[3] - 2.5).pow(2).mean(),
+            lambda state: state[4:6].norm(dim = -1)
+        ],
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512
+        )
+    )
+
+    import numpy as np
+
+    class MockEnv:
+        def reset(self):
+            return np.random.normal(size = (10,))
+
+        def step(self, *args, **kwargs):
+            return np.random.normal(size = (10,))
+
+    env = MockEnv()
+
+    reset_fn, step_fn = model.wrap_env_functions(env)
+
+    reset_fn()
+
+    _, rewards = step_fn(3)
+
+    assert len(rewards) == 2
```
```diff
--- locoformer-0.0.17/train_gym.py
+++ locoformer-0.0.30/train_gym.py
@@ -25,7 +25,6 @@ import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 from torch.optim import Adam
 
-import einx
 from einops import rearrange
 
 from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -60,8 +59,6 @@ def learn(
     batch_size = 16,
     epochs = 2,
 ):
-    device = accelerator.device
-
     dl = replay.dataloader(batch_size = batch_size, shuffle = True)
     model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
 
```
```diff
@@ -70,18 +67,14 @@ def learn(
 
         data = SimpleNamespace(**data)
 
-        seq_len = data.state.shape[1]
-
-        value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
-        value = torch.where(value_mask, data.value, 0.)
-
         actor_loss, critic_loss = model.ppo(
             state = data.state,
             action = data.action,
             old_action_log_prob = data.action_log_prob,
             reward = data.reward,
-            old_value = value,
+            old_value = data.value,
             mask = data.learnable,
+            episode_lens = data._lens,
             actor_optim = actor_optim,
             critic_optim = critic_optim
         )
```
```diff
@@ -94,7 +87,7 @@ def main(
     env_name = 'LunarLander-v3',
     num_episodes = 50_000,
     max_timesteps = 500,
-    num_episodes_before_learn =
+    num_episodes_before_learn = 64,
     clear_video = True,
     video_folder = 'recordings',
     record_every_episode = 250,
@@ -105,7 +98,8 @@ def main(
     ppo_eps_clip = 0.2,
     ppo_entropy_weight = .01,
     batch_size = 16,
-    epochs =
+    epochs = 3,
+    reward_range = (-100., 100.)
 ):
 
     # accelerate
```
```diff
@@ -153,7 +147,6 @@ def main(
     locoformer = Locoformer(
         embedder = MLP(dim_state, 64, bias = False),
         unembedder = MLP(64, num_actions, bias = False),
-        value_network = MLP(64, 1, bias = False),
         transformer = dict(
             dim = 64,
             dim_head = 32,
@@ -165,16 +158,20 @@ def main(
         gae_lam = gae_lam,
         ppo_eps_clip = ppo_eps_clip,
         ppo_entropy_weight = ppo_entropy_weight,
+        use_spo = True,
+        value_network = MLP(64, 64),
+        dim_value_input = 64,
+        reward_range = reward_range,
+        hl_gauss_loss_kwargs = dict(),
+        recurrent_kv_cache = True,
         calc_gae_kwargs = dict(
             use_accelerated = False
-        )
+        ),
     ).to(device)
 
     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)
 
-    timesteps_learn = 0
-
     # able to wrap the env for all values to torch tensors and back
     # all environments should follow usual MDP interface, domain randomization should be given at instantiation
 
```
```diff
@@ -205,7 +202,8 @@ def main(
 
             # append to memory
 
-
+            exceeds_max_timesteps = timestep == (max_timesteps - 1)
+            done = truncated or terminated or tensor(exceeds_max_timesteps)
 
             # get log prob of action
 
@@ -222,23 +220,24 @@ def main(
                 learnable = tensor(True)
             )
 
-            #
-            # only if terminated signal not detected
+            # increment counters
 
-            _, next_value = stateful_forward(next_state, return_values = True)
+            timestep += 1
 
-
+            # break if done or exceed max timestep
 
-
+            if done:
 
-
+                # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                # only if terminated signal not detected
 
-
+                if not terminated:
+                    _, next_value = stateful_forward(next_state, return_values = True)
 
-
+                    memory._replace(value = next_value, learnable = False)
+
+                replay.store(**memory._asdict())
 
-            if done or timestep >= max_timesteps:
                 break
 
             state = next_state
```