locoformer 0.0.11__tar.gz → 0.0.30__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {locoformer-0.0.11 → locoformer-0.0.30}/PKG-INFO +3 -2
- {locoformer-0.0.11 → locoformer-0.0.30}/README.md +1 -1
- {locoformer-0.0.11 → locoformer-0.0.30}/locoformer/locoformer.py +441 -59
- {locoformer-0.0.11 → locoformer-0.0.30}/pyproject.toml +2 -1
- {locoformer-0.0.11 → locoformer-0.0.30}/tests/test_locoformer.py +66 -5
- {locoformer-0.0.11 → locoformer-0.0.30}/train.py +2 -2
- {locoformer-0.0.11 → locoformer-0.0.30}/train_gym.py +95 -27
- {locoformer-0.0.11 → locoformer-0.0.30}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/.gitignore +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/LICENSE +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/data/README.md +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/data/enwik8.gz +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/fig3.png +0 -0
- {locoformer-0.0.11 → locoformer-0.0.30}/locoformer/__init__.py +0 -0
{locoformer-0.0.11 → locoformer-0.0.30}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.11
+Version: 0.0.30
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,6 +38,7 @@ Requires-Dist: assoc-scan
 Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
@@ -54,7 +55,7 @@ Description-Content-Type: text/markdown

 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

 ## Sponsors

{locoformer-0.0.11 → locoformer-0.0.30}/README.md

@@ -4,7 +4,7 @@

 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

 ## Sponsors

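The description above is the only prose change in this release, but the mechanism it alludes to, XL memories that persist across trials, is exactly what the `locoformer.py` changes below implement: `forward` now threads a `Cache` of `(curr_timestep, kv_cache)` through every call instead of resetting it each episode. A minimal sketch of that usage, reusing only the constructor arguments already shown in `tests/test_locoformer.py` further down (an illustration, not an official example from the package):

```python
import torch
from torch import nn
from x_mlps_pytorch import MLP
from locoformer.locoformer import Locoformer

# hedged sketch - constructor arguments copied from tests/test_locoformer.py below
model = Locoformer(
    embedder = nn.Embedding(256, 128),
    unembedder = nn.Linear(128, 256, bias = False),
    value_network = MLP(128, 64, 32),
    dim_value_input = 32,
    reward_range = (-100., 100.),
    transformer = dict(dim = 128, depth = 1, window_size = 512)
)

# two "trials" of token observations; the cache is deliberately not reset in
# between, so attention in the second trial can still see the first one
cache = None

for trial in (torch.randint(0, 256, (1, 4)), torch.randint(0, 256, (1, 4))):
    logits, cache = model(trial, cache = cache)
```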
{locoformer-0.0.11 → locoformer-0.0.30}/locoformer/locoformer.py

@@ -1,8 +1,10 @@
 from __future__ import annotations
+from typing import Callable
 from functools import partial

 from pathlib import Path
 from contextlib import contextmanager
+from collections import namedtuple

 import numpy as np
 from numpy import ndarray
@@ -15,8 +17,9 @@ import torch
 from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer

 import einx
 from einops import rearrange, einsum
@@ -24,10 +27,16 @@ from einops.layers.torch import Rearrange

 from rotary_embedding_torch import RotaryEmbedding

+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan

+# constants
+
 LinearNoBias = partial(Linear, bias = False)

+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions

 def exists(v):
@@ -39,15 +48,23 @@ def default(v, d):
 def first(arr):
     return arr[0]

+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0

+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

-def detach_all(x):
-    return tree_map_tensor(x, lambda t: t.detach())
-
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -61,13 +78,20 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)
+
+def calc_entropy(logits):
+    prob = logits.softmax(dim = -1)
+    return -(prob * log(prob)).sum(dim = -1)
+
 # generalized advantage estimate

 @torch.no_grad()
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -78,6 +102,9 @@ def calc_gae(
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]

+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks

@@ -87,7 +114,7 @@ def calc_gae(

     returns = gae + values

-    return returns
+    return gae, returns

 # transformer-xl mask w/ flex attn

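`calc_gae` now defaults `masks` to all ones and returns both the advantages and the bootstrapped returns that `ppo` consumes further down. For reference, the recurrence the scan computes is the standard GAE one; a plain-loop equivalent (an illustration, not the package's accelerated path) would look like:

```python
import torch

def calc_gae_reference(rewards, values, gamma = 0.99, lam = 0.95):
    # plain-python reference for the scan in `calc_gae` above (assumes no masks,
    # i.e. masks of all ones); values has the same length as rewards, and the
    # bootstrap value past the final step is taken to be 0
    gae = 0.
    advantages = torch.zeros_like(rewards)

    for t in reversed(range(rewards.shape[-1])):
        value_next = values[..., t + 1] if (t + 1) < rewards.shape[-1] else 0.
        delta = rewards[..., t] + gamma * value_next - values[..., t]
        gae = delta + gamma * lam * gae
        advantages[..., t] = gae

    returns = advantages + values
    return advantages, returns
```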
locoformer/locoformer.py (continued)

@@ -129,8 +156,8 @@ def create_xl_mask(
        # handle intra-episodic attention if needed

        if exists(episode_ids):
-            q_episode =
-            k_episode =
+            q_episode = episode_ids[b, q + offset]
+            k_episode = episode_ids[b, k]

            intra_episode_mask = q_episode == k_episode
            mask = mask & intra_episode_mask
@@ -231,12 +258,63 @@ class ReplayDataset(Dataset):

        episode_len = self.episode_lens[episode_index]

-        data = {field:
+        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}

        data['_lens'] = tensor(episode_len)

        return data

+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        if self.shuffle_episodes and episode_indices.numel() > 1:
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
 class ReplayBuffer:

     @beartype
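The new `RemappedReplayDataset` is what turns single episodes into multi-trial training sequences: each row of `episode_mapping` lists episode indices whose tensors are concatenated along time, with `_lens` holding the combined length and `_episode_indices` tagging every timestep with its source episode. A hedged sketch of building such a mapping from consecutive episode pairs (mirroring the pattern in `tests/test_locoformer.py` below, with the pairing made explicit):

```python
import torch

num_episodes = 8  # hypothetical replay buffer size
episode_indices = torch.arange(num_episodes)

# shape (num_pairs, 2): each row names the two episodes to concatenate in time,
# so trial i and trial i + 1 become one long in-context sequence
episode_mapping = torch.stack((episode_indices[:-1], episode_indices[1:]), dim = -1)

# replay_buffer.dataset(episode_mapping), or
# replay_buffer.dataloader(batch_size, episode_mapping = episode_mapping),
# would then yield the remapped multi-episode sequences
```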
locoformer/locoformer.py (continued)

@@ -299,6 +377,16 @@ class ReplayBuffer:
            self.shapes[field_name] = shape
            self.dtypes[field_name] = dtype

+        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
+    def reset_(self):
+        self.episode_lens[:] = 0
+        self.episode_index = 0
+        self.timestep_index = 0
+
     def advance_episode(self):
         self.episode_index = (self.episode_index + 1) % self.max_episodes
         self.timestep_index = 0
@@ -353,15 +441,93 @@ class ReplayBuffer:

        self.timestep_index += 1

-
+        return self.memory_namedtuple(**data)
+
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
        self.flush()

-
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset

-
+        return RemappedReplayDataset(dataset, episode_mapping)
+
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
        self.flush()

-        return DataLoader(self.dataset(), collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+# normalization + conditioning (needed for the commands to the robot)
+
+class MaybeAdaRMSNormWrapper(Module):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_cond = None
+    ):
+        super().__init__()
+        condition = exists(dim_cond)
+
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+        self.accept_condition = condition
+
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+            nn.init.zeros_(self.to_gamma.weight, 0.)
+            nn.init.zeros_(self.to_ada_norm_zero.weight, 0.)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+    def forward(
+        self,
+        x,
+        cond = None,
+        **kwargs
+    ):
+
+        need_cond = self.accept_condition
+        assert xnor(exists(cond), need_cond)
+
+        prenormed = self.norm(x)
+
+        if need_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')
+
+            scale_in = self.to_gamma(cond)
+            prenormed = prenormed * (scale_in + 1.)
+
+        all_fn_out = self.fn(prenormed, **kwargs)
+
+        if not need_cond:
+            return all_fn_out
+
+        # function may return multiple args
+
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+        if need_cond:
+            scale_out = self.to_ada_norm_zero(cond).sigmoid()
+            out = out * scale_out
+
+        # restore
+
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+        return all_fn_out

 # transformer-xl with ppo

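The pre-norm RMSNorm that used to live inside `Attention` is now handled by `MaybeAdaRMSNormWrapper`, which can also condition each block on a command embedding: the normalized input is scaled by `1 + to_gamma(cond)` and the block output is gated by `sigmoid(to_ada_norm_zero(cond))`. With both projections zero-initialized and the gate bias at -5, a conditioned block starts out contributing almost nothing to the residual stream, as this small check illustrates (an illustration of the initialization above, not code from the package):

```python
import torch

dim = 4
gamma_at_init = torch.zeros(dim)                  # to_gamma(cond) with zeroed weights
gate_at_init = torch.full((dim,), -5.).sigmoid()  # to_ada_norm_zero(cond) with zeroed weights, bias -5

print(gamma_at_init + 1.)  # tensor([1., 1., 1., 1.]) -> input passes through unscaled
print(gate_at_init)        # ~0.0067 everywhere      -> block output is gated nearly to zero
```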
locoformer/locoformer.py (continued)

@@ -372,15 +538,12 @@ class Attention(Module):
        window_size,
        dim_head = 64,
        heads = 8,
-        pre_rmsnorm = True,
        fixed_window_size = False,
        accept_value_residual = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5

-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -421,12 +584,9 @@ class Attention(Module):
        return_kv_cache = False,
    ):
        seq_len = tokens.shape[-2]
-        assert seq_len <= self.window_size

        device = tokens.device

-        tokens = self.norm(tokens)
-
        q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

        q, k, v = map(self.split_heads, (q, k, v))
@@ -515,19 +675,24 @@ class TransformerXL(Module):
        dim_head = 64,
        heads = 8,
        expansion_factor = 4.,
+        dim_cond = None,
        final_norm = True,
        fixed_window_size = False,
    ):
        super().__init__()

+        condition = exists(dim_cond)
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim_cond)
+
        layers = ModuleList([])

        for i in range(depth):
            is_first = i == 0

-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))

-            ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

            layers.append(ModuleList([
                attn, ff
@@ -582,7 +747,21 @@ class Locoformer(Module):
        embedder: Module,
        unembedder: Module,
        transformer: dict | TransformerXL,
-
+        discount_factor = 0.999,
+        gae_lam = 0.95,
+        ppo_eps_clip = 0.2,
+        ppo_entropy_weight = 0.01,
+        ppo_value_clip = 0.4,
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
    ):
        super().__init__()

@@ -594,11 +773,58 @@ class Locoformer(Module):
        self.embedder = embedder
        self.unembedder = unembedder

-        self.value_network = value_network
-
        self.fixed_window_size = transformer.fixed_window_size
        self.window_size = transformer.window_size

+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
+        # ppo related
+
+        self.discount_factor = discount_factor
+        self.gae_lam = gae_lam
+        self.ppo_eps_clip = ppo_eps_clip
+        self.ppo_entropy_weight = ppo_entropy_weight
+        self.ppo_value_clip = ppo_value_clip
+        self.value_loss_weight = value_loss_weight
+
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
    @property
    def device(self):
        return next(self.parameters()).device
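The scalar value head is gone: `to_value_pred` now emits `num_reward_bins` logits, and the `HLGaussLoss` from the newly added `hl-gauss-pytorch` dependency serves both as the critic loss against the returns and, called on logits alone, as the expectation that converts them back to a scalar value (see `forward` below). A rough, self-contained sketch of that histogram-as-regression idea, with a simple Gaussian smear over bin centers standing in for the library's construction:

```python
import torch
import torch.nn.functional as F

# hedged sketch of the "HL-Gauss" idea behind the value head: a scalar target
# is smeared into a soft histogram over fixed bins, the head is trained with
# cross entropy, and the scalar value is read back out as the expected bin
# center. illustration only - the package delegates to hl_gauss_pytorch.HLGaussLoss

min_value, max_value, num_bins, sigma = -100., 100., 32, 10.
centers = torch.linspace(min_value, max_value, num_bins)

def scalar_to_soft_target(target):                 # (b,) -> (b, num_bins)
    dist = -(target[:, None] - centers) ** 2 / (2 * sigma ** 2)
    return dist.softmax(dim = -1)

def logits_to_scalar(logits):                      # (b, num_bins) -> (b,)
    return (logits.softmax(dim = -1) * centers).sum(dim = -1)

logits = torch.randn(4, num_bins, requires_grad = True)
returns = torch.tensor([3., -20., 55., 0.])

loss = F.cross_entropy(logits, scalar_to_soft_target(returns))
loss.backward()

values = logits_to_scalar(logits.detach())         # scalar value estimates
```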
locoformer/locoformer.py (continued)

@@ -607,33 +833,163 @@ class Locoformer(Module):
        return self.unembedder.parameters()

    def critic_parameters(self):
-        if not exists(self.
+        if not exists(self.to_value_pred):
            return []

-        return self.
+        return self.to_value_pred.parameters()
+
+    def ppo(
+        self,
+        state,
+        action,
+        old_action_log_prob,
+        reward,
+        old_value,
+        mask,
+        episode_lens,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
+    ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
+
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            (
+                state,
+                action,
+                old_action_log_prob,
+                reward,
+                old_value,
+                mask,
+                advantage,
+                returns
+            )
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
+
+        # learn across windows
+
+        cache = None
+
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns
+        ) in zip(*windowed_tensors):
+
+            (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+            entropy = calc_entropy(action_logits)
+
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+            # update actor, classic clipped surrogate loss
+
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
+
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic
+
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
+
+            value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
+
+            # optimizer update
+
+            if exists(actor_optim):
+                actor_optim.step()
+                actor_optim.zero_grad()
+
+            if exists(critic_optim):
+                critic_optim.step()
+                critic_optim.zero_grad()
+
+        # return losses for logging
+
+        return mean_actor_loss.detach(), mean_critic_loss.detach()
+
+    def state_to_rewards(
+        self,
+        state
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+        return stack(rewards)

    def wrap_env_functions(self, env):

-        def
-
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el

-
-
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)

-            return
+            return tree_map(transform_output, env_reset_out)

        def wrapped_step(action, *args, **kwargs):
-            out = env.step(action.item(), *args, **kwargs)

-
-
-
-
-
-            return el
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
+
+            env_step_out_torch = tree_map(transform_output, env_step_out)

-
+            if not self.has_reward_shaping:
+                return env_step_out_torch
+
+            shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards

        return wrapped_reset, wrapped_step

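The `ppo` method above supports two per-token surrogate objectives: the classic clipped PPO loss, and, when `use_spo` is set, the Simple Policy Optimization penalty referenced in the constructor comment, which swaps the hard clip for a quadratic penalty on the ratio's distance from 1. Pulled out side by side (a self-contained restatement of the two expressions above, not additional package code):

```python
import torch

def ppo_clip_loss(ratio, advantage, eps_clip = 0.2):
    # classic clipped surrogate, as in the `else` branch of `ppo` above
    return -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

def spo_loss(ratio, advantage, eps_clip = 0.2):
    # simple policy optimization: a smooth quadratic penalty on the ratio's
    # deviation from 1, matching the `use_spo` branch of `ppo` above
    return -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))

ratio = torch.linspace(0.5, 1.5, 5)
advantage = torch.ones(5)

print(ppo_clip_loss(ratio, advantage))  # gradient cuts off past the clip in the favored direction
print(spo_loss(ratio, advantage))       # penalty grows smoothly as the ratio drifts from 1
```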
locoformer/locoformer.py (continued)

@@ -643,6 +999,7 @@ class Locoformer(Module):
        inference_mode = False,
        has_batch_dim = False,
        has_time_dim = False,
+        state_time_dim = 1,
        **kwargs
    ):
        window_size = self.window_size
@@ -658,23 +1015,16 @@ class Locoformer(Module):
            state = rearrange(state, '... -> 1 ...')

        if not has_time_dim:
-            state =
+            state = state.unsqueeze(state_time_dim)

        # forwards

        out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})

-        # handle cache
-
-        cache_len = cache.shape[-2]
-
-        if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-            cache = cache[..., -window_size:, :]
-
        # maybe remove batch or time

        if not has_time_dim:
-            out = tree_map_tensor(out, lambda t:
+            out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

        if not has_batch_dim:
            out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -703,16 +1053,35 @@ class Locoformer(Module):
    def forward(
        self,
        state: Tensor,
-        cache:
+        cache: Cache | None = None,
        detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
    ):

        state = state.to(self.device)

        tokens = self.embedder(state)

-
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
+
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)

        # unembed to actions - in language models this would be the next state

@@ -723,21 +1092,34 @@ class Locoformer(Module):
        # maybe detach cache

        if detach_cache:
-            kv_cache =
+            kv_cache = kv_cache.detach()

        # handle returning of values

        if return_values:
-            assert exists(self.
+            assert exists(self.to_value_pred)

-            values = self.
+            values = self.to_value_pred(embed)

-            if
-
-            values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values

            out = (out, values)

        # output and cache

-
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
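`forward` now returns the cache as a `(next_timestep, kv_cache)` pair, trims it back to `window_size` positions at segment boundaries, and, when `recurrent_kv_cache` is enabled, rolls the per-layer caches by one layer in the spirit of the Ding et al. reference above, so each layer's next segment reads the memories written by the layer above it. A shape-level illustration of that roll (the cache layout here is an assumption; only the layer-first dimension matters):

```python
import torch

# hypothetical cache layout: (layers, 2 for k/v, batch, heads, seq, dim_head)
layers = 4
kv_cache = torch.randn(layers, 2, 1, 8, 16, 64)

rolled = torch.roll(kv_cache, shifts = -1, dims = 0)

assert torch.equal(rolled[0], kv_cache[1])   # layer 0 now reads what layer 1 wrote
assert torch.equal(rolled[-1], kv_cache[0])  # the top layer wraps around to layer 0
```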
{locoformer-0.0.11 → locoformer-0.0.30}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.11"
+version = "0.0.30"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
     "beartype",
     "einx>=0.3.0",
     "einops>=0.8.0",
+    "hl-gauss-pytorch>=0.2.0",
     "rotary-embedding-torch",
     "torch>=2.4",
     "x-mlps-pytorch",
{locoformer-0.0.11 → locoformer-0.0.30}/tests/test_locoformer.py

@@ -2,18 +2,25 @@ import pytest
 param = pytest.mark.parametrize

 import torch
+from torch import nn
 from x_mlps_pytorch import MLP

 from einops import rearrange

-
-
-
+from locoformer.locoformer import Locoformer
+
+@param('recurrent_kv_cache', (False, True))
+def test_locoformer(
+    recurrent_kv_cache
+):

     model = Locoformer(
         embedder = nn.Embedding(256, 128),
         unembedder = nn.Linear(128, 256, bias = False),
-        value_network = MLP(128,
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        recurrent_kv_cache = recurrent_kv_cache,
         transformer = dict(
             dim = 128,
             depth = 1,
@@ -83,4 +90,58 @@ def test_replay():

     dataloader = replay_buffer.dataloader(batch_size = 3)

-    assert next(iter(dataloader))['state'].shape[0] == 3
+    assert next(iter(dataloader))['state'].shape[0] == 3
+
+    # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+    # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+    from torch import stack, arange
+
+    episode_indices = arange(len(replay_buffer))
+    remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+    dataloader = replay_buffer.dataloader(
+        batch_size = 1,
+        episode_mapping = remapped_episodes
+    )
+
+    assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+def test_reward_shaping():
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        reward_shaping_fns = [
+            lambda state: (state[3] - 2.5).pow(2).mean(),
+            lambda state: state[4:6].norm(dim = -1)
+        ],
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512
+        )
+    )
+
+    import numpy as np
+
+    class MockEnv:
+        def reset(self):
+            return np.random.normal(size = (10,))
+
+        def step(self, *args, **kwargs):
+            return np.random.normal(size = (10,))
+
+    env = MockEnv()
+
+    reset_fn, step_fn = model.wrap_env_functions(env)
+
+    reset_fn()
+
+    _, rewards = step_fn(3)
+
+    assert len(rewards) == 2
{locoformer-0.0.11 → locoformer-0.0.30}/train.py

@@ -160,7 +160,7 @@ for i in range(NUM_BATCHES):
     optim.step()
     optim.zero_grad()

-    if divisible_by(i
+    if divisible_by(i, GENERATE_EVERY):
         model.eval()

         val_seq = next(val_loader_iter)
@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
         prime = prime.to(model.device)
         out = prime

-        stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+        stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)

         # sample

{locoformer-0.0.11 → locoformer-0.0.30}/train_gym.py

@@ -3,7 +3,7 @@
 # "accelerate",
 # "fire",
 # "gymnasium[box2d]>=1.0.0",
-# "locoformer",
+# "locoformer>=0.0.12",
 # "moviepy",
 # "tqdm"
 # ]
@@ -13,13 +13,14 @@ from fire import Fire
 from shutil import rmtree
 from tqdm import tqdm
 from collections import deque
+from types import SimpleNamespace

 from accelerate import Accelerator

 import gymnasium as gym

 import torch
-from torch import from_numpy, randint, tensor, stack
+from torch import from_numpy, randint, tensor, stack, arange
 import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 from torch.optim import Adam
@@ -47,26 +48,64 @@ def gumbel_sample(logits, temperature = 1., eps = 1e-6):
     noise = gumbel_noise(logits)
     return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

+# learn
+
+def learn(
+    model,
+    actor_optim,
+    critic_optim,
+    accelerator,
+    replay,
+    batch_size = 16,
+    epochs = 2,
+):
+    dl = replay.dataloader(batch_size = batch_size, shuffle = True)
+    model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
+
+    for _ in range(epochs):
+        for data in dl:
+
+            data = SimpleNamespace(**data)
+
+            actor_loss, critic_loss = model.ppo(
+                state = data.state,
+                action = data.action,
+                old_action_log_prob = data.action_log_prob,
+                reward = data.reward,
+                old_value = data.value,
+                mask = data.learnable,
+                episode_lens = data._lens,
+                actor_optim = actor_optim,
+                critic_optim = critic_optim
+            )
+
+            accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
 # main function

 def main(
     env_name = 'LunarLander-v3',
     num_episodes = 50_000,
     max_timesteps = 500,
-
+    num_episodes_before_learn = 64,
     clear_video = True,
     video_folder = 'recordings',
     record_every_episode = 250,
+    learning_rate = 8e-4,
     discount_factor = 0.99,
-
+    betas = (0.9, 0.99),
+    gae_lam = 0.95,
+    ppo_eps_clip = 0.2,
+    ppo_entropy_weight = .01,
     batch_size = 16,
-    epochs =
+    epochs = 3,
+    reward_range = (-100., 100.)
 ):

     # accelerate

-
-    device =
+    accelerator = Accelerator()
+    device = accelerator.device

     # environment

@@ -91,14 +130,15 @@ def main(
     replay = ReplayBuffer(
         'replay',
         num_episodes,
-        max_timesteps,
+        max_timesteps + 1, # one extra node for bootstrap node - not relevant for locoformer, but for completeness
         fields = dict(
             state = ('float', (dim_state,)),
             action = 'int',
             action_log_prob = 'float',
             reward = 'float',
             value = 'float',
-            done = 'bool'
+            done = 'bool',
+            learnable = 'bool'
         )
     )

@@ -107,20 +147,30 @@ def main(
     locoformer = Locoformer(
         embedder = MLP(dim_state, 64, bias = False),
         unembedder = MLP(64, num_actions, bias = False),
-        value_network = MLP(64, 1, bias = False),
         transformer = dict(
             dim = 64,
             dim_head = 32,
             heads = 4,
             depth = 4,
             window_size = 16
-        )
+        ),
+        discount_factor = discount_factor,
+        gae_lam = gae_lam,
+        ppo_eps_clip = ppo_eps_clip,
+        ppo_entropy_weight = ppo_entropy_weight,
+        use_spo = True,
+        value_network = MLP(64, 64),
+        dim_value_input = 64,
+        reward_range = reward_range,
+        hl_gauss_loss_kwargs = dict(),
+        recurrent_kv_cache = True,
+        calc_gae_kwargs = dict(
+            use_accelerated = False
+        ),
     ).to(device)

-    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
-    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)
-
-    timesteps_learn = 0
+    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
+    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)

     # able to wrap the env for all values to torch tensors and back
     # all environments should follow usual MDP interface, domain randomization should be given at instantiation
@@ -129,7 +179,8 @@ def main(

     # loop

-    for
+    for episodes_index in tqdm(range(num_episodes)):
+
         state, *_ = env_reset()

         timestep = 0
@@ -151,42 +202,59 @@ def main(

             # append to memory

-
+            exceeds_max_timesteps = timestep == (max_timesteps - 1)
+            done = truncated or terminated or tensor(exceeds_max_timesteps)

             # get log prob of action

             action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
             action_log_prob = rearrange(action_log_prob, '1 ->')

-            replay.store(
+            memory = replay.store(
                 state = state,
                 action = action,
                 action_log_prob = action_log_prob,
                 reward = reward,
                 value = value,
-                done = done
+                done = done,
+                learnable = tensor(True)
             )

             # increment counters

             timestep += 1
-            timesteps_learn += 1

-            #
+            # break if done or exceed max timestep

-            if
-                # todo - carry out learning
+            if done:

-
-
+                # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                # only if terminated signal not detected

-
+                if not terminated:
+                    _, next_value = stateful_forward(next_state, return_values = True)
+
+                    memory._replace(value = next_value, learnable = False)
+
+                    replay.store(**memory._asdict())

-            if done or timestep >= max_timesteps:
                 break

             state = next_state

+        # learn if hit the number of learn timesteps
+
+        if divisible_by(episodes_index + 1, num_episodes_before_learn):
+
+            learn(
+                locoformer,
+                optim_actor,
+                optim_critic,
+                accelerator,
+                replay,
+                batch_size,
+                epochs,
+            )
 # main

 if __name__ == '__main__':