locoformer 0.0.29__py3-none-any.whl → 0.0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +325 -38
- {locoformer-0.0.29.dist-info → locoformer-0.0.43.dist-info}/METADATA +2 -1
- locoformer-0.0.43.dist-info/RECORD +6 -0
- {locoformer-0.0.29.dist-info → locoformer-0.0.43.dist-info}/WHEEL +1 -1
- locoformer-0.0.29.dist-info/RECORD +0 -6
- {locoformer-0.0.29.dist-info → locoformer-0.0.43.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,11 +1,14 @@
 from __future__ import annotations
 from typing import Callable
-from functools import partial
+from types import SimpleNamespace
+from functools import partial, wraps
 
 from pathlib import Path
 from contextlib import contextmanager
 from collections import namedtuple
 
+from inspect import signature
+
 import numpy as np
 from numpy import ndarray
 from numpy.lib.format import open_memmap
@@ -17,7 +20,7 @@ import torch
 from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
 from torch.optim import Optimizer
 
@@ -31,6 +34,10 @@ from hl_gauss_pytorch import HLGaussLoss
 
 from assoc_scan import AssocScan
 
+from x_mlps_pytorch import MLP
+
+from x_evolution import EvoStrategy
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -48,9 +55,40 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0
 
+def get_param_names(fn):
+    parameters = signature(fn).parameters
+    return list(parameters.keys())
+
+def check_has_param_attr(
+    param_name,
+    param_attr,
+    default_value = None
+):
+    def decorator(fn):
+        sig = signature(fn)
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+
+            bound_args = sig.bind(*args, **kwargs).arguments
+
+            if not (
+                param_name in bound_args and
+                hasattr(bound_args[param_name], param_attr)
+            ):
+                return default_value
+
+            return fn(*args, **kwargs)
+
+        return inner
+    return decorator
+
 # tensor helpers
 
 def log(t, eps = 1e-20):
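
A quick usage sketch of the new `check_has_param_attr` guard (assuming the helpers above are importable from `locoformer.locoformer`; the function `track_v_xy` and its values are made up for illustration): the decorated function only runs when the named argument carries the required attribute, otherwise the call short-circuits to `default_value`.

    import torch
    from types import SimpleNamespace

    from locoformer.locoformer import check_has_param_attr  # assumed import path

    @check_has_param_attr('state', 'v_xy', default_value = 0.)
    def track_v_xy(state, command):
        return (state.v_xy - command.v_xy).norm(dim = -1)

    command = SimpleNamespace(v_xy = torch.ones(2))

    # state carries `v_xy`, so the wrapped function runs
    print(track_v_xy(state = SimpleNamespace(v_xy = torch.zeros(2)), command = command))  # tensor(1.4142)

    # state lacks `v_xy`, so the guard returns the default instead of raising
    print(track_v_xy(state = SimpleNamespace(), command = command))  # 0.0
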
@@ -78,10 +116,87 @@ def pad_at_dim(
 def normalize(t, eps = 1e-5):
     return (t - t.mean()) / t.std().clamp_min(eps)
 
+def tensor_to_dict(
+    t: Tensor,
+    config: tuple[tuple[str, int] | str],
+    dim = -1,
+    return_dottable = True
+):
+    config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+    names, sizes = zip(*config)
+    assert sum(sizes) == t.shape[dim]
+
+    t = t.split(sizes, dim = dim)
+    tensor_dict = dict(zip(names, t))
+
+    if not return_dottable:
+        return tensor_dict
+
+    return SimpleNamespace(**tensor_dict)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)
 
+# reward functions - A.2
+
+@check_has_param_attr('state', 'v_xy')
+@check_has_param_attr('command', 'v_xy')
+def reward_linear_velocity_command_tracking(
+    state,
+    command,
+    s1 = 1.
+):
+    error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+    return torch.exp(-error / s1)
+
+@check_has_param_attr('state', 'w_z')
+@check_has_param_attr('command', 'w_z')
+def reward_angular_velocity_command_tracking(
+    state,
+    command,
+    s2 = 1.
+):
+    error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+    return torch.exp(-error / s2)
+
+@check_has_param_attr('state', 'v_z')
+def reward_base_linear_velocity_penalty(
+    state
+):
+    return -state.v_z.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'w_xy')
+def reward_base_angular_velocity_penalty(
+    state
+):
+    return -state.w_xy.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'x_z')
+def reward_base_height_penalty(
+    state,
+    x_z_nominal = 0.27
+):
+    return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'joint_q')
+def reward_joint_acceleration_penalty(
+    state
+):
+    return -state.joint_q.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'tau')
+def reward_torque_penalty(
+    state
+):
+    return -state.tau.norm(dim = -1).pow(2)
+
+def reward_alive(
+    state
+):
+    return 1.
+
 # generalized advantage estimate
 
 @torch.no_grad()
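
A minimal sketch of how `tensor_to_dict` slices a flat observation into the named, dot-accessible fields these reward functions expect (same import-path assumption as above; the 6-dim field layout below is hypothetical):

    import torch
    from locoformer.locoformer import tensor_to_dict, reward_linear_velocity_command_tracking

    # hypothetical proprioceptive slice: planar velocity, vertical velocity, angular rates
    state = tensor_to_dict(torch.randn(6), (('v_xy', 2), 'v_z', ('w_xy', 2), 'w_z'))
    command = tensor_to_dict(torch.tensor([0.5, 0.]), (('v_xy', 2),))

    print(state.v_xy)  # tensor of shape (2,)
    print(reward_linear_velocity_command_tracking(state = state, command = command))
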
@@ -266,7 +381,8 @@ class RemappedReplayDataset(Dataset):
         self,
         dataset: ReplayDataset,
         episode_mapping: Tensor | list[list[int]],
-        shuffle_episodes = False
+        shuffle_episodes = False,
+        num_trials_select = None
     ):
         assert len(dataset) > 0
         self.dataset = dataset
@@ -278,6 +394,10 @@ class RemappedReplayDataset(Dataset):
         self.episode_mapping = episode_mapping
         self.shuffle_episodes = shuffle_episodes
 
+        assert not (exists(num_trials_select) and num_trials_select >= 1)
+        self.sub_select_trials = exists(num_trials_select)
+        self.num_trials_select = num_trials_select
+
     def __len__(self):
         return len(self.episode_mapping)
 
@@ -290,10 +410,22 @@ class RemappedReplayDataset(Dataset):
 
         assert not is_empty(episode_indices)
 
-        if
+        # shuffle the episode indices if either shuffle episodes is turned on, or `num_trial_select` passed in (for sub selecting episodes from a set)
+
+        if (
+            episode_indices.numel() > 1 and
+            (self.shuffle_episodes or self.sub_select_trials)
+        ):
             num_episodes = len(episode_indices)
             episode_indices = episode_indices[torch.randperm(num_episodes)]
 
+        # crop out the episodes
+
+        if self.sub_select_trials:
+            episode_indices = episode_indices[:self.num_trials_select]
+
+        # now select out the episode data and merge along time
+
         episode_data = [self.dataset[i] for i in episode_indices.tolist()]
 
         episode_lens = stack([data.pop('_lens') for data in episode_data])
@@ -368,6 +500,10 @@ class ReplayBuffer:
         # memmap file
 
         filepath = folder / f'{field_name}.data.npy'
+
+        if isinstance(shape, int):
+            shape = (shape,)
+
         memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
 
         self.memmaps[field_name] = memmap
@@ -463,6 +599,70 @@ class ReplayBuffer:
 
         return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
 
+# normalization + conditioning (needed for the commands to the robot)
+
+class MaybeAdaRMSNormWrapper(Module):
+    def __init__(
+        self,
+        fn: Module,
+        dim,
+        dim_cond = None
+    ):
+        super().__init__()
+        condition = exists(dim_cond)
+
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+        self.accept_condition = condition
+
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+            nn.init.zeros_(self.to_gamma.weight)
+            nn.init.zeros_(self.to_ada_norm_zero.weight)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+    def forward(
+        self,
+        x,
+        cond = None,
+        **kwargs
+    ):
+
+        need_cond = self.accept_condition
+
+        assert xnor(exists(cond), need_cond)
+
+        prenormed = self.norm(x)
+
+        if need_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')
+
+            scale_in = self.to_gamma(cond)
+            prenormed = prenormed * (scale_in + 1.)
+
+        all_fn_out = self.fn(prenormed, **kwargs)
+
+        if not need_cond:
+            return all_fn_out
+
+        # function may return multiple args
+
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+        if need_cond:
+            scale_out = self.to_ada_norm_zero(cond).sigmoid()
+            out = out * scale_out
+
+        # restore
+
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+        return all_fn_out
+
 # transformer-xl with ppo
 
 class Attention(Module):
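
A small sketch of the wrapper in both modes (hypothetical dimensions; only the call pattern mirrors the diff): unconditioned it acts as a plain pre-RMSNorm around the wrapped module, conditioned it adds a learned input scale and a near-zero-initialized sigmoid gate on the output.

    import torch
    from torch import nn

    from locoformer.locoformer import MaybeAdaRMSNormWrapper  # assumed import path

    dim, dim_cond = 16, 8  # hypothetical sizes

    # unconditioned: behaves as pre-RMSNorm -> fn
    block = MaybeAdaRMSNormWrapper(nn.Linear(dim, dim), dim = dim)
    out = block(torch.randn(2, 4, dim))

    # conditioned: cond scales the normalized input and gates the output (gate starts near zero)
    cond_block = MaybeAdaRMSNormWrapper(nn.Linear(dim, dim), dim = dim, dim_cond = dim_cond)
    cond_out = cond_block(torch.randn(2, 4, dim), cond = torch.randn(2, dim_cond))

    print(out.shape, cond_out.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 16])
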
@@ -472,15 +672,12 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True,
         fixed_window_size = False,
         accept_value_residual = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
 
-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -524,8 +721,6 @@
 
         device = tokens.device
 
-        tokens = self.norm(tokens)
-
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
 
         q, k, v = map(self.split_heads, (q, k, v))
@@ -614,19 +809,26 @@ class TransformerXL(Module):
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
     ):
         super().__init__()
 
+        condition = exists(dim_cond)
+
+        self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
         layers = ModuleList([])
 
         for i in range(depth):
             is_first = i == 0
 
-            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))
 
-            ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
 
             layers.append(ModuleList([
                 attn, ff
@@ -644,20 +846,32 @@
         self,
         x,
         cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        condition: Tensor | None = None
     ):
 
+        # cache and residuals
+
         cache = default(cache, (None,) * len(self.layers))
 
         next_kv_caches = []
         value_residual = None
 
+        # handle condition
+
+        cond_tokens = None
+        if exists(condition):
+            assert exists(self.to_cond_tokens)
+            cond_tokens = self.to_cond_tokens(condition)
+
+        # layers
+
         for (attn, ff), kv_cache in zip(self.layers, cache):
 
-            attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+            attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
 
             x = attn_out + x
-            x = ff(x) + x
+            x = ff(x, cond = cond_tokens) + x
 
             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)
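
A sketch of the conditioning path these two hunks introduce, without instantiating the full TransformerXL (the token and command widths below are hypothetical; the `MLP(dim_cond, dim * 2, activate_last = True)` call and the `cond =` keyword mirror the diff):

    import torch
    from torch import nn
    from functools import partial

    from x_mlps_pytorch import MLP
    from locoformer.locoformer import MaybeAdaRMSNormWrapper  # assumed import path

    dim, dim_cond = 64, 12  # hypothetical token and command widths

    # command -> (dim * 2) cond tokens, shared by every wrapped block
    to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True)
    norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim * 2)

    # stand-in feedforward block, wrapped the same way attn and ff are wrapped above
    ff = norm_fn(nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)))

    x = torch.randn(2, 16, dim)          # (batch, time, dim) tokens
    command = torch.randn(2, dim_cond)   # e.g. target velocities for the robot

    cond_tokens = to_cond_tokens(command)
    x = ff(x, cond = cond_tokens) + x    # residual, as in TransformerXL.forward
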
@@ -678,7 +892,7 @@
 class Locoformer(Module):
     def __init__(
         self,
-        embedder: Module,
+        embedder: Module | ModuleList | list[Module],
         unembedder: Module,
         transformer: dict | TransformerXL,
         discount_factor = 0.999,
@@ -686,10 +900,10 @@
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
-        dim_value_input = None,
+        dim_value_input = None, # needs to be set for value network to be available
         value_network: Module = nn.Identity(),
         reward_range: tuple[float, float] | None = None,
-        reward_shaping_fns: list[Callable[
+        reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
         num_reward_bins = 32,
         hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
@@ -704,7 +918,15 @@
 
         self.transformer = transformer
 
+        # handle state embedder
+
+        if isinstance(embedder, list):
+            embedder = ModuleList(embedder)
+
         self.embedder = embedder
+
+        # unembed state to actions or ssl predictions
+
         self.unembedder = unembedder
 
         self.fixed_window_size = transformer.fixed_window_size
@@ -746,7 +968,7 @@
 
         self.use_spo = use_spo
 
-        # maybe recurrent kv cache
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
 
         self.recurrent_kv_cache = recurrent_kv_cache
 
@@ -772,6 +994,14 @@
 
         return self.to_value_pred.parameters()
 
+    def evolve(
+        self,
+        environment,
+        **kwargs
+    ):
+        evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+        evo_strat()
+
     def ppo(
         self,
         state,
@@ -781,6 +1011,8 @@
         old_value,
         mask,
         episode_lens,
+        condition: Tensor | None = None,
+        state_type: int | None = None,
         actor_optim: Optimizer | None = None,
         critic_optim: Optimizer | None = None
     ):
@@ -794,18 +1026,25 @@
 
         advantage = normalize(advantage)
 
+        data_tensors = (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns
+        )
+
+        has_condition = exists(condition)
+
+        if exists(condition):
+            data_tensors = (*data_tensors, condition)
+
         windowed_tensors = [
             t.split(window_size, dim = 1) for t in
-            (
-                state,
-                action,
-                old_action_log_prob,
-                reward,
-                old_value,
-                mask,
-                advantage,
-                returns
-            )
+            data_tensors
         ]
 
         mean_actor_loss = self.zero.clone()
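
The PPO update then splits every tensor, including the optional condition, into truncated windows along time; a tiny illustration of that `t.split(window_size, dim = 1)` pattern (all shapes below are made up):

    import torch

    state = torch.randn(2, 10, 4)        # (batch, time, dim_state)
    condition = torch.randn(2, 10, 8)     # (batch, time, dim_condition)

    window_size = 4
    windowed = [t.split(window_size, dim = 1) for t in (state, condition)]

    for state_chunk, cond_chunk in zip(*windowed):
        print(state_chunk.shape, cond_chunk.shape)
    # torch.Size([2, 4, 4]) torch.Size([2, 4, 8])
    # torch.Size([2, 4, 4]) torch.Size([2, 4, 8])
    # torch.Size([2, 2, 4]) torch.Size([2, 2, 8])
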
@@ -823,10 +1062,14 @@
             old_value,
             mask,
             advantage,
-            returns
+            returns,
+            *rest
         ) in zip(*windowed_tensors):
 
-
+            if has_condition:
+                condition, = rest
+
+            (action_logits, value_logits), cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
             entropy = calc_entropy(action_logits)
 
             action = rearrange(action, 'b t -> b t 1')
@@ -882,16 +1125,33 @@
 
         return mean_actor_loss.detach(), mean_critic_loss.detach()
 
-    def
+    def state_and_command_to_rewards(
         self,
-        state
+        state,
+        commands = None
     ) -> Tensor:
 
         assert self.has_reward_shaping
 
-        rewards = [
+        rewards = []
+
+        for fn in self.reward_shaping_fns:
+            param_names = get_param_names(fn)
+            param_names = set(param_names) & {'state', 'command'}
+
+            if param_names == {'state'}: # only state
+                reward = fn(state = state)
+            elif param_names == {'state', 'command'}: # state and command
+                reward = fn(state = state, command = commands)
+            else:
+                raise ValueError('invalid number of arguments for reward shaping function')
+
+            rewards.append(reward)
+
+        # cast to Tensor if returns a float, just make it flexible for researcher
 
         rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+
         return stack(rewards)
 
     def wrap_env_functions(self, env):
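
The dispatch in `state_and_command_to_rewards` keys off each reward function's parameter names; a quick check of that logic (same import-path assumption as the earlier sketches):

    from locoformer.locoformer import (
        get_param_names,
        reward_alive,
        reward_base_height_penalty,
        reward_linear_velocity_command_tracking,
    )

    for fn in (reward_alive, reward_base_height_penalty, reward_linear_velocity_command_tracking):
        print(fn.__name__, sorted(set(get_param_names(fn)) & {'state', 'command'}))

    # reward_alive ['state']                                       -> dispatched as fn(state = state)
    # reward_base_height_penalty ['state']                         -> dispatched as fn(state = state)
    # reward_linear_velocity_command_tracking ['command', 'state'] -> dispatched as fn(state = state, command = commands)
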
@@ -921,7 +1181,7 @@
         if not self.has_reward_shaping:
             return env_step_out_torch
 
-        shaped_rewards = self.
+        shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
 
         return env_step_out_torch, shaped_rewards
 
@@ -940,20 +1200,36 @@
 
         cache = None
 
-        def stateful_forward(
+        def stateful_forward(
+            state: Tensor,
+            condition: Tensor | None = None,
+            state_type: int | None = None,
+            **override_kwargs
+        ):
             nonlocal cache
 
+            state = state.to(self.device)
+
+            if exists(condition):
+                condition = condition.to(self.device)
+
             # handle no batch or time, for easier time rolling out against envs
 
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')
 
+                if exists(condition):
+                    condition = rearrange(condition, '... -> 1 ...')
+
             if not has_time_dim:
                 state = state.unsqueeze(state_time_dim)
 
+                if exists(condition):
+                    condition = rearrange(condition, '... d -> ... 1 d')
+
             # forwards
 
-            out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+            out, cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, **{**kwargs, **override_kwargs})
 
             # maybe remove batch or time
 
@@ -988,6 +1264,8 @@
         self,
         state: Tensor,
         cache: Cache | None = None,
+        condition: Tensor | None = None,
+        state_type: int | None = None,
         detach_cache = False,
         return_values = False,
         return_raw_value_logits = False
@@ -995,7 +1273,16 @@
 
         state = state.to(self.device)
 
-
+        # determine which function to invoke for state to token for transformer
+
+        state_to_token = self.embedder
+
+        if exists(state_type):
+            state_to_token = self.embedder[state_type]
+
+        # embed
+
+        tokens = state_to_token(state)
 
         # time
 
@@ -1015,7 +1302,7 @@
 
         # attention
 
-        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)
+        embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)
 
         # unembed to actions - in language models this would be the next state
 

{locoformer-0.0.29.dist-info → locoformer-0.0.43.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.29
+Version: 0.0.43
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -41,6 +41,7 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
+Requires-Dist: x-evolution
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Requires-Dist: accelerate; extra == 'examples'

locoformer-0.0.43.dist-info/RECORD

@@ -0,0 +1,6 @@
+locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+locoformer/locoformer.py,sha256=5gQTtseqs92K9ee9HJ1gEqhm8MFPFDFXPnoPxLnf8Nw,37531
+locoformer-0.0.43.dist-info/METADATA,sha256=Vgx50wEmRpwrGxoOntARE2oU7g5TdqcM2ZUvrpOBjIk,3283
+locoformer-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+locoformer-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+locoformer-0.0.43.dist-info/RECORD,,

locoformer-0.0.29.dist-info/RECORD

@@ -1,6 +0,0 @@
-locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
-locoformer/locoformer.py,sha256=Tr_1btuoTZ0huXeDcAeuHxTPaVeCUEGc5iLvMYGDLck,29982
-locoformer-0.0.29.dist-info/METADATA,sha256=5Fi3EOsgpBvpzAFVZQyrlink-HcHE8EgFl10Y5l8mqM,3256
-locoformer-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-locoformer-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-locoformer-0.0.29.dist-info/RECORD,,

{locoformer-0.0.29.dist-info → locoformer-0.0.43.dist-info}/licenses/LICENSE
File without changes