locoformer 0.0.15__py3-none-any.whl → 0.0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +599 -91
- {locoformer-0.0.15.dist-info → locoformer-0.0.43.dist-info}/METADATA +4 -2
- locoformer-0.0.43.dist-info/RECORD +6 -0
- {locoformer-0.0.15.dist-info → locoformer-0.0.43.dist-info}/WHEEL +1 -1
- locoformer-0.0.15.dist-info/RECORD +0 -6
- {locoformer-0.0.15.dist-info → locoformer-0.0.43.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,10 +1,14 @@
 from __future__ import annotations
-from
+from typing import Callable
+from types import SimpleNamespace
+from functools import partial, wraps
 
 from pathlib import Path
 from contextlib import contextmanager
 from collections import namedtuple
 
+from inspect import signature
+
 import numpy as np
 from numpy import ndarray
 from numpy.lib.format import open_memmap
@@ -16,8 +20,9 @@ import torch
 from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer
 
 import einx
 from einops import rearrange, einsum
@@ -25,10 +30,20 @@ from einops.layers.torch import Rearrange
 
 from rotary_embedding_torch import RotaryEmbedding
 
+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan
 
+from x_mlps_pytorch import MLP
+
+from x_evolution import EvoStrategy
+
+# constants
+
 LinearNoBias = partial(Linear, bias = False)
 
+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions
 
 def exists(v):
@@ -40,20 +55,51 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0
 
+def get_param_names(fn):
+    parameters = signature(fn).parameters
+    return list(parameters.keys())
+
+def check_has_param_attr(
+    param_name,
+    param_attr,
+    default_value = None
+):
+    def decorator(fn):
+        sig = signature(fn)
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+
+            bound_args = sig.bind(*args, **kwargs).arguments
+
+            if not (
+                param_name in bound_args and
+                hasattr(bound_args[param_name], param_attr)
+            ):
+                return default_value
+
+            return fn(*args, **kwargs)
+
+        return inner
+    return decorator
+
 # tensor helpers
 
 def log(t, eps = 1e-20):
     return t.clamp_min(eps).log()
 
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
-def detach_all(x):
-    return tree_map_tensor(x, lambda t: t.detach())
-
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -67,10 +113,90 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)
+
+def tensor_to_dict(
+    t: Tensor,
+    config: tuple[tuple[str, int] | str],
+    dim = -1,
+    return_dottable = True
+):
+    config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+    names, sizes = zip(*config)
+    assert sum(sizes) == t.shape[dim]
+
+    t = t.split(sizes, dim = dim)
+    tensor_dict = dict(zip(names, t))
+
+    if not return_dottable:
+        return tensor_dict
+
+    return SimpleNamespace(**tensor_dict)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)
 
+# reward functions - A.2
+
+@check_has_param_attr('state', 'v_xy')
+@check_has_param_attr('command', 'v_xy')
+def reward_linear_velocity_command_tracking(
+    state,
+    command,
+    s1 = 1.
+):
+    error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+    return torch.exp(-error / s1)
+
+@check_has_param_attr('state', 'w_z')
+@check_has_param_attr('command', 'w_z')
+def reward_angular_velocity_command_tracking(
+    state,
+    command,
+    s2 = 1.
+):
+    error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+    return torch.exp(-error / s2)
+
+@check_has_param_attr('state', 'v_z')
+def reward_base_linear_velocity_penalty(
+    state
+):
+    return -state.v_z.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'w_xy')
+def reward_base_angular_velocity_penalty(
+    state
+):
+    return -state.w_xy.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'x_z')
+def reward_base_height_penalty(
+    state,
+    x_z_nominal = 0.27
+):
+    return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'joint_q')
+def reward_joint_acceleration_penalty(
+    state
+):
+    return -state.joint_q.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'tau')
+def reward_torque_penalty(
+    state
+):
+    return -state.tau.norm(dim = -1).pow(2)
+
+def reward_alive(
+    state
+):
+    return 1.
+
 # generalized advantage estimate
 
 @torch.no_grad()
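To make the new helpers concrete, here is a minimal usage sketch (assumed, not code from the package): `tensor_to_dict` splits a flat state vector into named fields, and `check_has_param_attr` makes a reward term quietly return `None` whenever the state or command it receives lacks the attribute that term needs, so the same shaping functions can be reused across embodiments. The field layout below is purely illustrative.

```python
import torch

# illustrative field layout, not the layout of any particular robot
flat_state = torch.randn(10)
state = tensor_to_dict(flat_state, (('v_xy', 2), ('w_z', 1), ('x_z', 1), ('joint_q', 6)))

command = tensor_to_dict(torch.randn(3), (('v_xy', 2), ('w_z', 1)))

reward_base_height_penalty(state = state)                                  # scalar tensor penalty
reward_linear_velocity_command_tracking(state = state, command = command)  # value in (0, 1]
reward_torque_penalty(state = state)                                       # None, no `tau` field, decorator short-circuits
```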
@@ -100,7 +226,7 @@ def calc_gae(
 
     returns = gae + values
 
-    return returns
+    return gae, returns
 
 # transformer-xl mask w/ flex attn
 
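`calc_gae` now returns the advantage alongside the bootstrapped returns. As a reminder of what the two quantities are, here is a hand-rolled toy version of the same recursion (not the package's masked, scan-based implementation):

```python
import torch

rewards = torch.tensor([1., 1., 1.])
values  = torch.tensor([0.5, 0.5, 0.5])
gamma, lam = 0.99, 0.95

gae, next_value, advantages = 0., 0., []

# standard GAE recursion, bootstrapping with a value of 0 after the final step
for t in reversed(range(len(rewards))):
    delta = rewards[t] + gamma * next_value - values[t]
    gae = delta + gamma * lam * gae
    advantages.append(gae)
    next_value = values[t]

advantages = torch.stack(list(reversed(advantages)))
returns = advantages + values   # mirrors `returns = gae + values` above
```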
@@ -250,6 +376,74 @@ class ReplayDataset(Dataset):
 
         return data
 
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False,
+        num_trials_select = None
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+        assert not (exists(num_trials_select) and num_trials_select >= 1)
+        self.sub_select_trials = exists(num_trials_select)
+        self.num_trials_select = num_trials_select
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        # shuffle the episode indices if either shuffle episodes is turned on, or `num_trial_select` passed in (for sub selecting episodes from a set)
+
+        if (
+            episode_indices.numel() > 1 and
+            (self.shuffle_episodes or self.sub_select_trials)
+        ):
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        # crop out the episodes
+
+        if self.sub_select_trials:
+            episode_indices = episode_indices[:self.num_trials_select]
+
+        # now select out the episode data and merge along time
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
 class ReplayBuffer:
 
     @beartype
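`RemappedReplayDataset` is what stitches separately recorded episodes into one long multi-trial sequence. The self-contained sketch below (plain torch, not the package API) shows the core of its `__getitem__`: gather the mapped episodes, concatenate them along time, and tag every timestep with the episode it came from, as `_episode_indices` does above.

```python
import torch
from torch import cat

# three stored episodes of shape (time, obs_dim)
episodes = [torch.randn(5, 3), torch.randn(7, 3), torch.randn(4, 3)]

mapping_row = [2, 0]   # one multi-trial sample: episode 2 followed by episode 0

merged = cat([episodes[i] for i in mapping_row], dim = 0)
episode_indices = cat([torch.full((episodes[i].shape[0],), i) for i in mapping_row])

assert merged.shape == (4 + 5, 3)
assert episode_indices.tolist() == [2] * 4 + [0] * 5
```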
@@ -306,6 +500,10 @@ class ReplayBuffer:
             # memmap file
 
             filepath = folder / f'{field_name}.data.npy'
+
+            if isinstance(shape, int):
+                shape = (shape,)
+
             memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
 
             self.memmaps[field_name] = memmap
@@ -314,6 +512,9 @@ class ReplayBuffer:
 
         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
 
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
     def reset_(self):
         self.episode_lens[:] = 0
         self.episode_index = 0
@@ -375,15 +576,92 @@ class ReplayBuffer:
 
         return self.memory_namedtuple(**data)
 
-    def dataset(
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
         self.flush()
 
-
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset
 
-
+        return RemappedReplayDataset(dataset, episode_mapping)
+
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
         self.flush()
 
-        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+# normalization + conditioning (needed for the commands to the robot)
+
+class MaybeAdaRMSNormWrapper(Module):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_cond = None
+    ):
+        super().__init__()
+        condition = exists(dim_cond)
+
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+        self.accept_condition = condition
+
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+            nn.init.zeros_(self.to_gamma.weight)
+            nn.init.zeros_(self.to_ada_norm_zero.weight)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+    def forward(
+        self,
+        x,
+        cond = None,
+        **kwargs
+    ):
+
+        need_cond = self.accept_condition
+
+        assert xnor(exists(cond), need_cond)
+
+        prenormed = self.norm(x)
+
+        if need_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')
+
+            scale_in = self.to_gamma(cond)
+            prenormed = prenormed * (scale_in + 1.)
+
+        all_fn_out = self.fn(prenormed, **kwargs)
+
+        if not need_cond:
+            return all_fn_out
+
+        # function may return multiple args
+
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+        if need_cond:
+            scale_out = self.to_ada_norm_zero(cond).sigmoid()
+            out = out * scale_out
+
+        # restore
+
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+        return all_fn_out
 
 # transformer-xl with ppo
 
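A small usage sketch of `MaybeAdaRMSNormWrapper` (wrapping a plain `nn.Linear` only to keep the example short; in the package it wraps the attention and feedforward blocks). The conditioning vector scales the pre-norm activations through `to_gamma` and gates the block output through the sigmoid of `to_ada_norm_zero`, whose bias of -5. means the wrapped block contributes almost nothing at initialization.

```python
import torch
from torch import nn

block = MaybeAdaRMSNormWrapper(nn.Linear(16, 16), dim = 16, dim_cond = 8)

x = torch.randn(2, 10, 16)   # (batch, time, dim)
cond = torch.randn(2, 8)     # per-sequence conditioning, e.g. a command embedding

out = block(x, cond = cond)  # same shape as x, near zero at init because of the output gate

assert out.shape == x.shape
```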
@@ -394,15 +672,12 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True,
         fixed_window_size = False,
         accept_value_residual = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
 
-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -446,8 +721,6 @@ class Attention(Module):
 
         device = tokens.device
 
-        tokens = self.norm(tokens)
-
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
 
         q, k, v = map(self.split_heads, (q, k, v))
@@ -536,19 +809,26 @@ class TransformerXL(Module):
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
     ):
         super().__init__()
 
+        condition = exists(dim_cond)
+
+        self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
         layers = ModuleList([])
 
         for i in range(depth):
             is_first = i == 0
 
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))
 
-            ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
 
             layers.append(ModuleList([
                 attn, ff
@@ -566,20 +846,32 @@ class TransformerXL(Module):
         self,
         x,
         cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        condition: Tensor | None = None
     ):
 
+        # cache and residuals
+
         cache = default(cache, (None,) * len(self.layers))
 
         next_kv_caches = []
         value_residual = None
 
+        # handle condition
+
+        cond_tokens = None
+        if exists(condition):
+            assert exists(self.to_cond_tokens)
+            cond_tokens = self.to_cond_tokens(condition)
+
+        # layers
+
         for (attn, ff), kv_cache in zip(self.layers, cache):
 
-            attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+            attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
 
             x = attn_out + x
-            x = ff(x) + x
+            x = ff(x, cond = cond_tokens) + x
 
             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)
@@ -600,16 +892,24 @@ class TransformerXL(Module):
 class Locoformer(Module):
     def __init__(
         self,
-        embedder: Module,
+        embedder: Module | ModuleList | list[Module],
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
-
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
     ):
         super().__init__()
 
@@ -618,14 +918,41 @@ class Locoformer(Module):
 
         self.transformer = transformer
 
+        # handle state embedder
+
+        if isinstance(embedder, list):
+            embedder = ModuleList(embedder)
+
         self.embedder = embedder
-        self.unembedder = unembedder
 
-
+        # unembed state to actions or ssl predictions
+
+        self.unembedder = unembedder
 
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
         # ppo related
 
         self.discount_factor = discount_factor
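The value head now predicts a categorical distribution over reward bins and is trained with the HL-Gauss formulation (a Gaussian-smoothed histogram target) instead of plain scalar regression. The sketch below mirrors the two ways the loss object is called in this diff, with a target to get a loss and without one to decode logits back to scalar values; the constructor arguments are the same ones passed above and the call signature is assumed from `hl-gauss-pytorch`.

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = 0., max_value = 10., num_bins = 32)

value_logits = torch.randn(2, 8, 32)   # (batch, time, num_reward_bins) from the value head
returns = torch.rand(2, 8) * 10.       # scalar regression targets

loss = hl_gauss(value_logits, returns) # cross entropy against the smoothed histogram target
values = hl_gauss(value_logits)        # no target, so decode the bins back to scalar values
```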
@@ -635,6 +962,25 @@ class Locoformer(Module):
         self.ppo_value_clip = ppo_value_clip
         self.value_loss_weight = value_loss_weight
 
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     @property
     def device(self):
         return next(self.parameters()).device
@@ -643,10 +989,18 @@ class Locoformer(Module):
         return self.unembedder.parameters()
 
     def critic_parameters(self):
-        if not exists(self.
+        if not exists(self.to_value_pred):
             return []
 
-        return self.
+        return self.to_value_pred.parameters()
+
+    def evolve(
+        self,
+        environment,
+        **kwargs
+    ):
+        evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+        evo_strat()
 
     def ppo(
         self,
@@ -656,79 +1010,180 @@ class Locoformer(Module):
         reward,
         old_value,
         mask,
-
-
+        episode_lens,
+        condition: Tensor | None = None,
+        state_type: int | None = None,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
     ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
 
-
-
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
 
-
-
-
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
+        data_tensors = (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns
+        )
 
-
+        has_condition = exists(condition)
 
-
-
+        if exists(condition):
+            data_tensors = (*data_tensors, condition)
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            data_tensors
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
+
+        # learn across windows
+
+        cache = None
 
-
-
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns,
+            *rest
+        ) in zip(*windowed_tensors):
 
-
+            if has_condition:
+                condition, = rest
 
-
+            (action_logits, value_logits), cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+            entropy = calc_entropy(action_logits)
 
-
-
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
 
-
+            # update actor, classic clipped surrogate loss
 
-
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
 
-
-
-
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
 
-
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
 
-
-
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic
+
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
+
+            value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
 
         # optimizer update
 
-        actor_optim
-
+        if exists(actor_optim):
+            actor_optim.step()
+            actor_optim.zero_grad()
 
-        critic_optim
-
+        if exists(critic_optim):
+            critic_optim.step()
+            critic_optim.zero_grad()
 
         # return losses for logging
 
         return mean_actor_loss.detach(), mean_critic_loss.detach()
 
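The `use_spo` branch above swaps PPO's ratio clipping for the Simple Policy Optimization objective, a quadratic penalty on the ratio's deviation from 1 scaled by |advantage| / (2 * eps_clip). This toy comparison evaluates both expressions exactly as written in the hunk:

```python
import torch

ratio     = torch.tensor([0.7, 1.0, 1.5])   # new / old action probability
advantage = torch.tensor([1.0, -2.0, 0.5])
eps_clip  = 0.2

ppo_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
spo_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))

# ppo_loss -> tensor([-0.7000,  2.0000, -0.6000])
# spo_loss -> tensor([-0.4750,  2.0000, -0.4375])
```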
+    def state_and_command_to_rewards(
+        self,
+        state,
+        commands = None
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = []
+
+        for fn in self.reward_shaping_fns:
+            param_names = get_param_names(fn)
+            param_names = set(param_names) & {'state', 'command'}
+
+            if param_names == {'state'}: # only state
+                reward = fn(state = state)
+            elif param_names == {'state', 'command'}: # state and command
+                reward = fn(state = state, command = commands)
+            else:
+                raise ValueError('invalid number of arguments for reward shaping function')
+
+            rewards.append(reward)
+
+        # cast to Tensor if returns a float, just make it flexible for researcher
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+
+        return stack(rewards)
+
     def wrap_env_functions(self, env):
 
-        def
-
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el
 
-
-
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)
 
-            return
+            return tree_map(transform_output, env_reset_out)
 
         def wrapped_step(action, *args, **kwargs):
-            out = env.step(action.item(), *args, **kwargs)
 
-
-
-
-
-
-
-                return el
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
+
+            env_step_out_torch = tree_map(transform_output, env_step_out)
 
-
+            if not self.has_reward_shaping:
+                return env_step_out_torch
+
+            shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards
 
         return wrapped_reset, wrapped_step
 
@@ -738,38 +1193,48 @@ class Locoformer(Module):
         inference_mode = False,
         has_batch_dim = False,
         has_time_dim = False,
+        state_time_dim = 1,
         **kwargs
     ):
         window_size = self.window_size
 
        cache = None
 
-        def stateful_forward(
+        def stateful_forward(
+            state: Tensor,
+            condition: Tensor | None = None,
+            state_type: int | None = None,
+            **override_kwargs
+        ):
            nonlocal cache
 
+            state = state.to(self.device)
+
+            if exists(condition):
+                condition = condition.to(self.device)
+
             # handle no batch or time, for easier time rolling out against envs
 
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')
 
-
-
-
-            # forwards
+                if exists(condition):
+                    condition = rearrange(condition, '... -> 1 ...')
 
-
+            if not has_time_dim:
+                state = state.unsqueeze(state_time_dim)
 
-
+                if exists(condition):
+                    condition = rearrange(condition, '... d -> ... 1 d')
 
-
+            # forwards
 
-
-                cache = cache[..., -window_size:, :]
+            out, cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, **{**kwargs, **override_kwargs})
 
             # maybe remove batch or time
 
             if not has_time_dim:
-                out = tree_map_tensor(out, lambda t:
+                out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))
 
             if not has_batch_dim:
                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -798,16 +1263,46 @@ class Locoformer(Module):
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: Cache | None = None,
+        condition: Tensor | None = None,
+        state_type: int | None = None,
         detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
     ):
 
         state = state.to(self.device)
 
-
+        # determine which function to invoke for state to token for transformer
+
+        state_to_token = self.embedder
+
+        if exists(state_type):
+            state_to_token = self.embedder[state_type]
+
+        # embed
+
+        tokens = state_to_token(state)
+
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
 
-
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)
 
         # unembed to actions - in language models this would be the next state
 
@@ -818,21 +1313,34 @@ class Locoformer(Module):
         # maybe detach cache
 
         if detach_cache:
-            kv_cache =
+            kv_cache = kv_cache.detach()
 
         # handle returning of values
 
         if return_values:
-            assert exists(self.
+            assert exists(self.to_value_pred)
 
-            values = self.
+            values = self.to_value_pred(embed)
 
-            if
-
-                values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values
 
             out = (out, values)
 
         # output and cache
 
-
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
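`forward` now threads a `Cache` of `(curr_timestep, kv_cache)` through successive calls, curtails the kv cache at window boundaries, and optionally rolls it one layer down (`recurrent_kv_cache`) so lower layers can read memories written by higher layers in the previous segment. The assert above enforces that no single call straddles a window-segment boundary; this pure-Python toy just checks that boundary arithmetic:

```python
window_size = 16

def crosses_segment_boundary(timestep_start, num_new_steps):
    return (timestep_start % window_size) + num_new_steps > window_size

assert not crosses_segment_boundary(0, 16)   # exactly fills the first window
assert not crosses_segment_boundary(16, 1)   # one step rolled out into the next window
assert crosses_segment_boundary(10, 10)      # would straddle two windows, rejected by the assert
```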
{locoformer-0.0.15.dist-info → locoformer-0.0.43.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.15
+Version: 0.0.43
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,8 +38,10 @@ Requires-Dist: assoc-scan
 Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
+Requires-Dist: x-evolution
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Requires-Dist: accelerate; extra == 'examples'
@@ -54,7 +56,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
locoformer-0.0.43.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+locoformer/locoformer.py,sha256=5gQTtseqs92K9ee9HJ1gEqhm8MFPFDFXPnoPxLnf8Nw,37531
+locoformer-0.0.43.dist-info/METADATA,sha256=Vgx50wEmRpwrGxoOntARE2oU7g5TdqcM2ZUvrpOBjIk,3283
+locoformer-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+locoformer-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+locoformer-0.0.43.dist-info/RECORD,,
locoformer-0.0.15.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
-locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
-locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
-locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-locoformer-0.0.15.dist-info/RECORD,,
{locoformer-0.0.15.dist-info → locoformer-0.0.43.dist-info}/licenses/LICENSE
File without changes