locoformer 0.0.17__tar.gz → 0.0.37__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {locoformer-0.0.17 → locoformer-0.0.37}/PKG-INFO +4 -2
- {locoformer-0.0.17 → locoformer-0.0.37}/README.md +1 -1
- {locoformer-0.0.17 → locoformer-0.0.37}/locoformer/locoformer.py +458 -63
- {locoformer-0.0.17 → locoformer-0.0.37}/pyproject.toml +3 -1
- locoformer-0.0.37/tests/test_locoformer.py +182 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/train.py +1 -1
- {locoformer-0.0.17 → locoformer-0.0.37}/train_gym.py +25 -26
- locoformer-0.0.17/tests/test_locoformer.py +0 -86
- {locoformer-0.0.17 → locoformer-0.0.37}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/.gitignore +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/LICENSE +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/data/README.md +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/data/enwik8.gz +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/fig3.png +0 -0
- {locoformer-0.0.17 → locoformer-0.0.37}/locoformer/__init__.py +0 -0
{locoformer-0.0.17 → locoformer-0.0.37}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.17
+Version: 0.0.37
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,8 +38,10 @@ Requires-Dist: assoc-scan
 Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
+Requires-Dist: x-evolution
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Requires-Dist: accelerate; extra == 'examples'
@@ -54,7 +56,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
{locoformer-0.0.17 → locoformer-0.0.37}/README.md

@@ -4,7 +4,7 @@
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
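The cross-trial, in-context adaptation described in the README maps onto the new cache plumbing in `locoformer/locoformer.py` below: the forward pass now returns `(output, (next_timestep, kv_cache))`, and feeding that cache back in lets a later trial condition on the memories of an earlier one. A minimal sketch (not part of the package; the model configuration mirrors the new `tests/test_locoformer.py`):

```python
import torch
from torch import nn
from x_mlps_pytorch import MLP
from locoformer.locoformer import Locoformer

# toy model, configured the same way as in the new test file
model = Locoformer(
    embedder = nn.Embedding(256, 128),
    unembedder = nn.Linear(128, 256, bias = False),
    value_network = MLP(128, 64, 32),
    dim_value_input = 32,
    reward_range = (-100., 100.),
    transformer = dict(dim = 128, depth = 1, window_size = 512)
)

trial_one = torch.randint(0, 256, (1, 512))  # stand-in for the tokens of a first trial
trial_two = torch.randint(0, 256, (1, 512))  # ... and a second trial of the same task

# carry the XL cache across the trial boundary, so the policy in trial two
# can adapt based on what happened in trial one (in-context adaptation)
_, cache = model(trial_one)
(logits, values), cache = model(trial_two, cache = cache, return_values = True)
```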
{locoformer-0.0.17 → locoformer-0.0.37}/locoformer/locoformer.py

@@ -1,10 +1,14 @@
 from __future__ import annotations
+from typing import Callable
+from types import SimpleNamespace
 from functools import partial
 
 from pathlib import Path
 from contextlib import contextmanager
 from collections import namedtuple
 
+from inspect import signature
+
 import numpy as np
 from numpy import ndarray
 from numpy.lib.format import open_memmap
@@ -16,7 +20,7 @@ import torch
 from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
 from torch.optim import Optimizer
 
@@ -26,12 +30,20 @@ from einops.layers.torch import Rearrange
 
 from rotary_embedding_torch import RotaryEmbedding
 
+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan
 
+from x_mlps_pytorch import MLP
+
+from x_evolution import EvoStrategy
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
 
+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions
 
 def exists(v):
@@ -43,14 +55,24 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0
 
+def get_param_names(fn):
+    parameters = signature(fn).parameters
+    return list(parameters.keys())
+
 # tensor helpers
 
 def log(t, eps = 1e-20):
     return t.clamp_min(eps).log()
 
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
@@ -67,10 +89,102 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)
+
+def tensor_to_dict(
+    t: Tensor,
+    config: tuple[tuple[str, int] | str],
+    dim = -1,
+    return_dottable = True
+):
+    config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+    names, sizes = zip(*config)
+    assert sum(sizes) == t.shape[dim]
+
+    t = t.split(sizes, dim = dim)
+    tensor_dict = dict(zip(names, t))
+
+    if not return_dottable:
+        return tensor_dict
+
+    return SimpleNamespace(**tensor_dict)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)
 
+# reward functions - A.2
+
+def reward_linear_velocity_command_tracking(
+    state,
+    command,
+    s1 = 1.
+):
+    if not (hasattr(state, 'v_xy') and hasattr(command, 'v_xy')):
+        return 0.
+
+    error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+    return torch.exp(-error / s1)
+
+def reward_angular_velocity_command_tracking(
+    state,
+    command,
+    s2 = 1.
+):
+    if not (hasattr(state, 'w_z') and hasattr(command, 'w_z')):
+        return 0.
+
+    error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+    return torch.exp(-error / s2)
+
+def reward_base_linear_velocity_penalty(
+    state
+):
+    if not hasattr(state, 'v_z'):
+        return 0.
+
+    return -state.v_z.norm(dim = -1).pow(2)
+
+def reward_base_angular_velocity_penalty(
+    state
+):
+    if not hasattr(state, 'w_xy'):
+        return 0.
+
+    return -state.w_xy.norm(dim = -1).pow(2)
+
+def reward_base_height_penalty(
+    state,
+    x_z_nominal = 0.27
+):
+    if not hasattr(state, 'x_z'):
+        return 0.
+
+    return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+def reward_joint_acceleration_penalty(
+    state
+):
+    if not hasattr(state, 'joint_q'):
+        return 0.
+
+    return -state.joint_q.norm(dim = -1).pow(2)
+
+def reward_torque_penalty(
+    state
+):
+    if not hasattr(state, 'tau'):
+        return 0.
+
+    return -state.tau.norm(dim = -1).pow(2)
+
+def reward_alive(
+    state
+):
+    return 1.
+
 # generalized advantage estimate
 
 @torch.no_grad()
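The reward functions above expect dottable states and commands (fields like `v_xy`, `w_z`), which is exactly what `tensor_to_dict` produces from a flat observation vector. A small sketch of how the two pieces fit together; the observation layout here is made up purely for illustration:

```python
import torch
from locoformer.locoformer import tensor_to_dict, reward_linear_velocity_command_tracking

# hypothetical flat observation: 2 dims of base xy velocity, 1 dim of yaw rate, 3 leftover dims
obs = torch.randn(6)
state = tensor_to_dict(obs, (('v_xy', 2), 'w_z', ('rest', 3)))

# command with the same named layout
command = tensor_to_dict(torch.tensor([0.5, 0.0, 0.1]), (('v_xy', 2), 'w_z'))

# exp(-||v_xy error||^2 / s1), per reward_linear_velocity_command_tracking above
tracking_reward = reward_linear_velocity_command_tracking(state = state, command = command)
```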
@@ -250,6 +364,57 @@ class ReplayDataset(Dataset):
 
         return data
 
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        if self.shuffle_episodes and episode_indices.numel() > 1:
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
 class ReplayBuffer:
 
     @beartype
@@ -314,6 +479,9 @@ class ReplayBuffer:
 
         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
 
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
     def reset_(self):
         self.episode_lens[:] = 0
         self.episode_index = 0
@@ -375,15 +543,92 @@ class ReplayBuffer:
 
         return self.memory_namedtuple(**data)
 
-    def dataset(
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
         self.flush()
 
-
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset
+
+        return RemappedReplayDataset(dataset, episode_mapping)
 
-    def dataloader(
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
         self.flush()
 
-        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+# normalization + conditioning (needed for the commands to the robot)
+
+class MaybeAdaRMSNormWrapper(Module):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_cond = None
+    ):
+        super().__init__()
+        condition = exists(dim_cond)
+
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+        self.accept_condition = condition
+
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+            nn.init.zeros_(self.to_gamma.weight)
+            nn.init.zeros_(self.to_ada_norm_zero.weight)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+    def forward(
+        self,
+        x,
+        cond = None,
+        **kwargs
+    ):
+
+        need_cond = self.accept_condition
+
+        assert xnor(exists(cond), need_cond)
+
+        prenormed = self.norm(x)
+
+        if need_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')
+
+            scale_in = self.to_gamma(cond)
+            prenormed = prenormed * (scale_in + 1.)
+
+        all_fn_out = self.fn(prenormed, **kwargs)
+
+        if not need_cond:
+            return all_fn_out
+
+        # function may return multiple args
+
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+        if need_cond:
+            scale_out = self.to_ada_norm_zero(cond).sigmoid()
+            out = out * scale_out
+
+        # restore
+
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+        return all_fn_out
 
 # transformer-xl with ppo
 
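`MaybeAdaRMSNormWrapper` above is what lets velocity commands condition every block: the normalized input is scaled by `1 + gamma(cond)` with `gamma` zero-initialized, and the block output is gated by a sigmoid whose bias starts at -5, so a freshly initialized conditioned block contributes almost nothing. A usage sketch, assuming the class is importable exactly as it appears in this diff:

```python
import torch
from torch import nn
from locoformer.locoformer import MaybeAdaRMSNormWrapper

# wrap any token-wise module; passing dim_cond turns on the adaptive path
block = MaybeAdaRMSNormWrapper(nn.Linear(64, 64), dim = 64, dim_cond = 16)

tokens = torch.randn(2, 10, 64)
cond = torch.randn(2, 16)          # e.g. an embedded locomotion command

out = block(tokens, cond = cond)   # same shape as tokens

# at init, to_gamma is zero and the output gate is sigmoid(-5) ≈ 0.0067,
# so the wrapped block is nearly silent until the conditioning is learned
```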
@@ -394,15 +639,12 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True,
         fixed_window_size = False,
         accept_value_residual = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
 
-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -446,8 +688,6 @@
 
         device = tokens.device
 
-        tokens = self.norm(tokens)
-
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
 
         q, k, v = map(self.split_heads, (q, k, v))
@@ -536,19 +776,26 @@ class TransformerXL(Module):
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
     ):
         super().__init__()
 
+        condition = exists(dim_cond)
+
+        self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
         layers = ModuleList([])
 
         for i in range(depth):
             is_first = i == 0
 
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))
 
-            ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
 
             layers.append(ModuleList([
                 attn, ff
@@ -566,20 +813,32 @@ class TransformerXL(Module):
         self,
         x,
         cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        condition: Tensor | None = None
     ):
 
+        # cache and residuals
+
         cache = default(cache, (None,) * len(self.layers))
 
        next_kv_caches = []
        value_residual = None
 
+        # handle condition
+
+        cond_tokens = None
+        if exists(condition):
+            assert exists(self.to_cond_tokens)
+            cond_tokens = self.to_cond_tokens(condition)
+
+        # layers
+
         for (attn, ff), kv_cache in zip(self.layers, cache):
 
-            attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+            attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
 
             x = attn_out + x
-            x = ff(x) + x
+            x = ff(x, cond = cond_tokens) + x
 
             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)
@@ -603,14 +862,21 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
-        calc_gae_kwargs: dict = dict()
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
     ):
         super().__init__()
 
@@ -622,11 +888,30 @@ class Locoformer(Module):
         self.embedder = embedder
         self.unembedder = unembedder
 
-        self.value_network = value_network
-
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
         # ppo related
 
         self.discount_factor = discount_factor
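The critic is no longer a scalar regressor in this release: `to_value_pred` emits logits over `num_reward_bins`, and `HLGaussLoss` (the histogram / HL-Gauss approach pulled in via the new `hl-gauss-pytorch` dependency) is used both to train against returns and to decode logits back into scalar values, mirroring the two call forms used later in `ppo()`. A small sketch of those call patterns, with made-up shapes:

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -100., max_value = 100., num_bins = 32)

value_logits = torch.randn(2, 10, 32)   # what to_value_pred would output
returns = torch.randn(2, 10) * 10.

# classification-style value loss against Gaussian-smeared target bins
value_loss = hl_gauss(value_logits, returns, reduction = 'none')

# calling with logits only decodes them back to scalar expected values
values = hl_gauss(value_logits)
```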
@@ -638,6 +923,19 @@ class Locoformer(Module):
 
         self.calc_gae_kwargs = calc_gae_kwargs
 
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
         # loss related
 
         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -650,10 +948,18 @@ class Locoformer(Module):
         return self.unembedder.parameters()
 
     def critic_parameters(self):
-        if not exists(self.
+        if not exists(self.to_value_pred):
             return []
 
-        return self.
+        return self.to_value_pred.parameters()
+
+    def evolve(
+        self,
+        environment,
+        **kwargs
+    ):
+        evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+        evo_strat()
 
     def ppo(
         self,
@@ -663,12 +969,20 @@
         reward,
         old_value,
         mask,
+        episode_lens,
         actor_optim: Optimizer | None = None,
         critic_optim: Optimizer | None = None
     ):
         window_size = self.window_size
         total_learnable_tokens = mask.sum().item()
 
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
         windowed_tensors = [
             t.split(window_size, dim = 1) for t in
             (
@@ -677,7 +991,9 @@
                 old_action_log_prob,
                 reward,
                 old_value,
-                mask
+                mask,
+                advantage,
+                returns
             )
         ]
 
@@ -694,10 +1010,12 @@
             old_action_log_prob,
             reward,
             old_value,
-            mask
+            mask,
+            advantage,
+            returns
         ) in zip(*windowed_tensors):
 
-            (action_logits,
+            (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
             entropy = calc_entropy(action_logits)
 
             action = rearrange(action, 'b t -> b t 1')
@@ -709,9 +1027,10 @@
             eps_clip = self.ppo_eps_clip
             ratio = (log_prob - old_action_log_prob).exp()
 
-
-
-
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
 
             actor_loss = actor_loss - self.ppo_entropy_weight * entropy
 
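For reference, the two surrogate objectives selected by `use_spo` above, written out with r_t the probability ratio and Â_t the normalized advantage (a direct transcription of the code: PPO-clip versus the quadratic-penalty form of simple policy optimization):

```latex
\mathcal{L}_{\mathrm{PPO}} = -\,\mathbb{E}_t\Big[\min\big(r_t\,\hat{A}_t,\ \mathrm{clip}(r_t,\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big]
\qquad
\mathcal{L}_{\mathrm{SPO}} = -\,\mathbb{E}_t\Big[r_t\,\hat{A}_t \;-\; \frac{|\hat{A}_t|\,(r_t-1)^2}{2\,\epsilon}\Big]
```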
@@ -720,11 +1039,13 @@
 
             # update critic
 
-            value_loss =
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
 
             value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
             clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
-            clipped_value_loss =
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
 
             critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
 
@@ -750,28 +1071,65 @@
 
         return mean_actor_loss.detach(), mean_critic_loss.detach()
 
+    def state_and_command_to_rewards(
+        self,
+        state,
+        commands = None
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = []
+
+        for fn in self.reward_shaping_fns:
+            param_names = get_param_names(fn)
+            param_names = set(param_names) & {'state', 'command'}
+
+            if param_names == {'state'}: # only state
+                reward = fn(state = state)
+            elif param_names == {'state', 'command'}: # state and command
+                reward = fn(state = state, command = commands)
+            else:
+                raise ValueError('invalid number of arguments for reward shaping function')
+
+            rewards.append(reward)
+
+        # cast to Tensor if returns a float, just make it flexible for researcher
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+
+        return stack(rewards)
+
     def wrap_env_functions(self, env):
 
-        def
-
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el
 
-
-
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)
 
-            return
+            return tree_map(transform_output, env_reset_out)
 
         def wrapped_step(action, *args, **kwargs):
-            out = env.step(action.item(), *args, **kwargs)
 
-
-
-
-
-
-
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
+
+            env_step_out_torch = tree_map(transform_output, env_step_out)
+
+            if not self.has_reward_shaping:
+                return env_step_out_torch
 
-
+            shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards
 
         return wrapped_reset, wrapped_step
 
@@ -781,13 +1139,18 @@
         inference_mode = False,
         has_batch_dim = False,
         has_time_dim = False,
+        state_time_dim = 1,
         **kwargs
     ):
         window_size = self.window_size
 
         cache = None
 
-        def stateful_forward(
+        def stateful_forward(
+            state: Tensor,
+            condition: Tensor | None = None,
+            **override_kwargs
+        ):
             nonlocal cache
 
             # handle no batch or time, for easier time rolling out against envs
@@ -795,24 +1158,23 @@
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')
 
-
-
-
-            # forwards
+                if exists(command):
+                    condition = rearrange(condition, '... -> 1 ...')
 
-
+            if not has_time_dim:
+                state = state.unsqueeze(state_time_dim)
 
-
+                if exists(command):
+                    condition = rearrange(condition, '... d -> ... 1 d')
 
-
+            # forwards
 
-
-                cache = cache[..., -window_size:, :]
+            out, cache = self.forward(state, condition = condition, cache = cache, **{**kwargs, **override_kwargs})
 
             # maybe remove batch or time
 
             if not has_time_dim:
-                out = tree_map_tensor(out, lambda t:
+                out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))
 
             if not has_batch_dim:
                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -841,16 +1203,36 @@
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: Cache | None = None,
+        condition: Tensor | None = None,
         detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
     ):
 
         state = state.to(self.device)
 
         tokens = self.embedder(state)
 
-
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
+
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)
 
         # unembed to actions - in language models this would be the next state
 
@@ -866,16 +1248,29 @@
         # handle returning of values
 
         if return_values:
-            assert exists(self.
+            assert exists(self.to_value_pred)
 
-            values = self.
+            values = self.to_value_pred(embed)
 
-            if
-
-                values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values
 
             out = (out, values)
 
         # output and cache
 
-
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
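One detail worth noting from the forward pass above: when `recurrent_kv_cache` is on, the stacked per-layer cache is rolled along its layer dimension at window boundaries, so the cache a layer reads on the next segment is the one a deeper layer just produced, extending how far back memories can reach (the Ding et al. reference in the comment). A toy illustration of what that `torch.roll` does, assuming the per-layer caches are stacked along dim 0 as the `dims = 0` in the code implies; the shapes here are invented:

```python
import torch

num_layers = 4

# pretend cache: (layers, batch, heads, window, dim_head), with the layer index baked into the values
kv_cache = torch.arange(num_layers).float().view(num_layers, 1, 1, 1, 1).expand(num_layers, 2, 8, 512, 64)

shifted = torch.roll(kv_cache, shifts = -1, dims = 0)

# layer 0 now holds what layer 1 produced on the previous segment (the last layer wraps around)
assert shifted[0].eq(kv_cache[1]).all()
```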
{locoformer-0.0.17 → locoformer-0.0.37}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.17"
+version = "0.0.37"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,8 +30,10 @@ dependencies = [
     "beartype",
     "einx>=0.3.0",
     "einops>=0.8.0",
+    "hl-gauss-pytorch>=0.2.0",
     "rotary-embedding-torch",
     "torch>=2.4",
+    "x-evolution",
     "x-mlps-pytorch",
 ]
 
locoformer-0.0.37/tests/test_locoformer.py (new file)

@@ -0,0 +1,182 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from torch import nn
+from x_mlps_pytorch import MLP
+
+from einops import rearrange
+
+from locoformer.locoformer import Locoformer
+
+@param('recurrent_kv_cache', (False, True))
+@param('has_commands', (False, True))
+def test_locoformer(
+    recurrent_kv_cache,
+    has_commands
+):
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        recurrent_kv_cache = recurrent_kv_cache,
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512,
+            dim_cond = 2 if has_commands else None
+        )
+    )
+
+    seq = torch.randint(0, 256, (3, 512))
+
+    commands = None
+    if has_commands:
+        commands = torch.randn(3, 512, 2)
+
+    (logits, values), cache = model(seq, condition = commands, return_values = True)
+    (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)
+    (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)
+
+    assert logits.shape == (3, 512, 256)
+
+    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
+
+    inference_command = torch.randn(1, 1, 2) if has_commands else None
+
+    for state in seq.unbind(dim = -1):
+        state = rearrange(state, 'b -> b 1')
+
+        logits, values = stateful_forward(state, condition = inference_command)
+        assert logits.shape == (3, 1, 256)
+
+def test_replay():
+    from locoformer.locoformer import ReplayBuffer
+
+    replay_buffer = ReplayBuffer(
+        './replay_data',
+        max_episodes = 10_000,
+        max_timesteps = 501,
+        fields = dict(
+            state = ('float', (8,)),
+            action = 'int',
+            action_log_prob = 'float',
+            reward = 'float',
+            value = 'float',
+            done = 'bool'
+        )
+    )
+
+    lens = [3, 5, 4]
+
+    for episode_len in lens:
+        with replay_buffer.one_episode():
+            for _ in range(episode_len):
+                state = torch.randn((8,))
+                action = torch.randint(0, 4, ())
+                log_prob = torch.randn(())
+                reward = torch.randn(())
+                value = torch.randn(())
+                done = torch.randint(0, 2, ()).bool()
+
+                replay_buffer.store(
+                    state = state,
+                    action = action,
+                    action_log_prob = log_prob,
+                    reward = reward,
+                    value = value,
+                    done = done
+                )
+
+    dataset = replay_buffer.dataset()
+
+    assert len(dataset) == 3
+
+    assert torch.is_tensor(dataset[0]['state'])
+
+    dataloader = replay_buffer.dataloader(batch_size = 3)
+
+    assert next(iter(dataloader))['state'].shape[0] == 3
+
+    # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+    # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+    from torch import stack, arange
+
+    episode_indices = arange(len(replay_buffer))
+    remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+    dataloader = replay_buffer.dataloader(
+        batch_size = 1,
+        episode_mapping = remapped_episodes
+    )
+
+    assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+def test_reward_shaping():
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        reward_shaping_fns = [
+            lambda state: (state[3] - 2.5).pow(2).mean(),
+            lambda state, command: state[4:6].norm(dim = -1)
+        ],
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512
+        )
+    )
+
+    import numpy as np
+
+    class MockEnv:
+        def reset(self):
+            return np.random.normal(size = (10,))
+
+        def step(self, *args, **kwargs):
+            return np.random.normal(size = (10,))
+
+
+    env = MockEnv()
+
+    reset_fn, step_fn = model.wrap_env_functions(env)
+
+    reset_fn()
+
+    _, rewards = step_fn(3)
+
+    assert len(rewards) == 2
+
+def test_tensor_to_dict():
+    state = torch.randn(1, 3, 5)
+    config = (('xyz', 3), 'vx', 'vy')
+
+    from locoformer.locoformer import tensor_to_dict
+
+    state_dict = tensor_to_dict(state, config)
+    assert hasattr(state_dict, 'xyz') and state_dict.xyz.shape == (1, 3, 3)
+
+def test_evo():
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512,
+        )
+    )
+
+    model.evolve(lambda model: 1., num_generations = 1)
{locoformer-0.0.17 → locoformer-0.0.37}/train_gym.py

@@ -25,7 +25,6 @@ import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 from torch.optim import Adam
 
-import einx
 from einops import rearrange
 
 from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -60,8 +59,6 @@ def learn(
     batch_size = 16,
     epochs = 2,
 ):
-    device = accelerator.device
-
     dl = replay.dataloader(batch_size = batch_size, shuffle = True)
     model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
 
@@ -70,18 +67,14 @@ def learn(
 
         data = SimpleNamespace(**data)
 
-        seq_len = data.state.shape[1]
-
-        value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
-        value = torch.where(value_mask, data.value, 0.)
-
         actor_loss, critic_loss = model.ppo(
             state = data.state,
             action = data.action,
             old_action_log_prob = data.action_log_prob,
             reward = data.reward,
-            old_value = value,
+            old_value = data.value,
             mask = data.learnable,
+            episode_lens = data._lens,
             actor_optim = actor_optim,
             critic_optim = critic_optim
         )
@@ -94,7 +87,7 @@ def main(
     env_name = 'LunarLander-v3',
     num_episodes = 50_000,
     max_timesteps = 500,
-    num_episodes_before_learn =
+    num_episodes_before_learn = 64,
     clear_video = True,
     video_folder = 'recordings',
     record_every_episode = 250,
@@ -105,7 +98,8 @@ def main(
     ppo_eps_clip = 0.2,
     ppo_entropy_weight = .01,
     batch_size = 16,
-    epochs =
+    epochs = 3,
+    reward_range = (-100., 100.)
 ):
 
     # accelerate
@@ -153,7 +147,6 @@ def main(
     locoformer = Locoformer(
         embedder = MLP(dim_state, 64, bias = False),
         unembedder = MLP(64, num_actions, bias = False),
-        value_network = MLP(64, 1, bias = False),
         transformer = dict(
             dim = 64,
             dim_head = 32,
@@ -165,16 +158,20 @@ def main(
         gae_lam = gae_lam,
         ppo_eps_clip = ppo_eps_clip,
         ppo_entropy_weight = ppo_entropy_weight,
+        use_spo = True,
+        value_network = MLP(64, 64),
+        dim_value_input = 64,
+        reward_range = reward_range,
+        hl_gauss_loss_kwargs = dict(),
+        recurrent_kv_cache = True,
         calc_gae_kwargs = dict(
             use_accelerated = False
-        )
+        ),
     ).to(device)
 
     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)
 
-    timesteps_learn = 0
-
     # able to wrap the env for all values to torch tensors and back
     # all environments should follow usual MDP interface, domain randomization should be given at instantiation
 
@@ -205,7 +202,8 @@ def main(
 
             # append to memory
 
-
+            exceeds_max_timesteps = timestep == (max_timesteps - 1)
+            done = truncated or terminated or tensor(exceeds_max_timesteps)
 
             # get log prob of action
 
@@ -222,23 +220,24 @@ def main(
                 learnable = tensor(True)
             )
 
-            #
-            # only if terminated signal not detected
+            # increment counters
 
-
-            _, next_value = stateful_forward(next_state, return_values = True)
+            timestep += 1
 
-
+            # break if done or exceed max timestep
 
-
+            if done:
 
-
+                # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                # only if terminated signal not detected
 
-
+                if not terminated:
+                    _, next_value = stateful_forward(next_state, return_values = True)
 
-
+                    memory._replace(value = next_value, learnable = False)
+
+                    replay.store(**memory._asdict())
 
-            if done or timestep >= max_timesteps:
                 break
 
             state = next_state
locoformer-0.0.17/tests/test_locoformer.py (deleted)

@@ -1,86 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from x_mlps_pytorch import MLP
-
-from einops import rearrange
-
-def test_locoformer():
-    from locoformer.locoformer import Locoformer
-    from torch import nn
-
-    model = Locoformer(
-        embedder = nn.Embedding(256, 128),
-        unembedder = nn.Linear(128, 256, bias = False),
-        value_network = MLP(128, 32, 1),
-        transformer = dict(
-            dim = 128,
-            depth = 1,
-            window_size = 512
-        )
-    )
-
-    seq = torch.randint(0, 256, (3, 512))
-
-    (logits, values), cache = model(seq, return_values = True)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-
-    assert logits.shape == (3, 512, 256)
-
-    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
-
-    for state in seq.unbind(dim = -1):
-        state = rearrange(state, 'b -> b 1')
-
-        logits, values = stateful_forward(state)
-        assert logits.shape == (3, 1, 256)
-
-def test_replay():
-    from locoformer.locoformer import ReplayBuffer
-
-    replay_buffer = ReplayBuffer(
-        './replay_data',
-        max_episodes = 10_000,
-        max_timesteps = 501,
-        fields = dict(
-            state = ('float', (8,)),
-            action = 'int',
-            action_log_prob = 'float',
-            reward = 'float',
-            value = 'float',
-            done = 'bool'
-        )
-    )
-
-    lens = [3, 5, 4]
-
-    for episode_len in lens:
-        with replay_buffer.one_episode():
-            for _ in range(episode_len):
-                state = torch.randn((8,))
-                action = torch.randint(0, 4, ())
-                log_prob = torch.randn(())
-                reward = torch.randn(())
-                value = torch.randn(())
-                done = torch.randint(0, 2, ()).bool()
-
-                replay_buffer.store(
-                    state = state,
-                    action = action,
-                    action_log_prob = log_prob,
-                    reward = reward,
-                    value = value,
-                    done = done
-                )
-
-    dataset = replay_buffer.dataset()
-
-    assert len(dataset) == 3
-
-    assert torch.is_tensor(dataset[0]['state'])
-
-    dataloader = replay_buffer.dataloader(batch_size = 3)
-
-    assert next(iter(dataloader))['state'].shape[0] == 3