locoformer 0.0.15__py3-none-any.whl → 0.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +296 -75
- {locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/METADATA +3 -2
- locoformer-0.0.29.dist-info/RECORD +6 -0
- locoformer-0.0.15.dist-info/RECORD +0 -6
- {locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/WHEEL +0 -0
- {locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 from functools import partial
 
 from pathlib import Path
@@ -18,6 +19,7 @@ import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
 from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer
 
 import einx
 from einops import rearrange, einsum
@@ -25,10 +27,16 @@ from einops.layers.torch import Rearrange
 
 from rotary_embedding_torch import RotaryEmbedding
 
+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan
 
+# constants
+
 LinearNoBias = partial(Linear, bias = False)
 
+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions
 
 def exists(v):
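The new `Cache` namedtuple carries the absolute timestep alongside the kv cache, and the reworked `forward` (last hunk below) uses that counter to curtail and recurse the cache at window boundaries. A minimal sketch of threading it through a rollout, assuming a constructed `model` and window-sized `state` chunks (both hypothetical):

    cache = None

    for state in states:
        out, cache = model(state, cache = cache)

        # Cache is just (curr_timestep, kv_cache)
        curr_timestep, kv_cache = cache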
@@ -48,12 +56,12 @@ def divisible_by(num, den):
 def log(t, eps = 1e-20):
     return t.clamp_min(eps).log()
 
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
-def detach_all(x):
-    return tree_map_tensor(x, lambda t: t.detach())
-
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -67,6 +75,9 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)
@@ -100,7 +111,7 @@ def calc_gae(
 
     returns = gae + values
 
-    return returns
+    return gae, returns
 
 # transformer-xl mask w/ flex attn
 
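`calc_gae` now returns the advantages alongside the returns, so the reworked `ppo` below can normalize and consume them directly rather than re-deriving `returns - values`. For reference, a self-contained sketch of the masked GAE recurrence being computed, written as a plain backward loop (the package itself presumably parallelizes this with `AssocScan`):

    import torch

    def calc_gae_reference(rewards, values, masks, gamma = 0.999, lam = 0.95):
        # masks: (batch, time) floats, 1. while the episode is still running

        # bootstrap with a value of 0 past the final step
        values_next = torch.cat((values[:, 1:], torch.zeros_like(values[:, :1])), dim = 1)

        delta = rewards + gamma * values_next * masks - values

        # run gae_t = delta_t + gamma * lam * mask_t * gae_{t+1} backwards in time
        gae = torch.zeros_like(delta[:, -1])
        gaes = []

        for t in reversed(range(delta.shape[1])):
            gae = delta[:, t] + gamma * lam * masks[:, t] * gae
            gaes.append(gae)

        gae = torch.stack(gaes[::-1], dim = 1)

        return gae, gae + values # same (gae, returns) pair as the new signature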
@@ -250,6 +261,57 @@ class ReplayDataset(Dataset):
 
         return data
 
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        if self.shuffle_episodes and episode_indices.numel() > 1:
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
 class ReplayBuffer:
 
     @beartype
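`RemappedReplayDataset` is what lets the XL memories span multiple trials: it stitches several stored episodes into one long sequence. A small usage sketch, with hypothetical episode indices and an assumed populated `ReplayDataset`:

    # stitch episodes 0 and 1 into one multi-episode sequence, keep episode 2 alone
    episode_mapping = [[0, 1], [2]]

    remapped = RemappedReplayDataset(replay_dataset, episode_mapping, shuffle_episodes = True)

    multi_episode = remapped[0]
    # tensor fields are concatenated across time, '_lens' holds the summed length,
    # and '_episode_indices' tags each timestep with its source episode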
@@ -314,6 +376,9 @@ class ReplayBuffer:
 
         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
 
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
     def reset_(self):
         self.episode_lens[:] = 0
         self.episode_index = 0
@@ -375,15 +440,28 @@ class ReplayBuffer:
 
         return self.memory_namedtuple(**data)
 
-    def dataset(self) -> Dataset:
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
         self.flush()
 
-        return ReplayDataset(self.folder)
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset
 
-    def dataloader(self, batch_size, **kwargs) -> DataLoader:
+        return RemappedReplayDataset(dataset, episode_mapping)
+
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
         self.flush()
 
-        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
 
 # transformer-xl with ppo
 
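The same `episode_mapping` threads through `ReplayBuffer.dataloader`, so multi-episode sequences can be batched directly (hypothetical buffer and mapping):

    loader = buffer.dataloader(batch_size = 4, episode_mapping = [[0, 1], [2, 3]], shuffle = True)

    for batch in loader:
        ...  # variable-length multi-episode sequences, collated by collate_var_time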
@@ -603,13 +681,21 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
-        value_loss_weight = 0.5,
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
     ):
         super().__init__()
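Note the API change: instead of handing over a complete `value_network`, callers now set `dim_value_input` to switch on the distributional value head, and `reward_range` becomes mandatory with it. A schematic construction exercising the new keywords; the embedder, unembedder, and transformer arguments are assumptions for illustration, not values from this diff:

    import torch.nn as nn

    model = Locoformer(
        embedder = nn.Linear(17, 512),    # hypothetical observation embedder
        unembedder = nn.Linear(512, 12),  # hypothetical action unembedder
        transformer = dict(dim = 512, depth = 4, window_size = 64), # assumed keys, check TransformerXL's signature
        dim_value_input = 512,            # enables the HL-Gauss value head
        reward_range = (-10., 10.),       # required once dim_value_input is set
        num_reward_bins = 32,
        use_spo = True
    )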
@@ -621,11 +707,30 @@ class Locoformer(Module):
         self.embedder = embedder
         self.unembedder = unembedder
 
-        self.value_network = value_network
-
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
         # ppo related
 
         self.discount_factor = discount_factor
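For reading the hunks below, it helps to know `HLGaussLoss` is dual-purpose: called with logits and a target it is a cross-entropy loss over Gaussian-smeared bins, and called with logits alone it collapses the bins back to a scalar expectation. A small sketch of both modes exactly as this diff uses them, with assumed shapes:

    import torch
    from hl_gauss_pytorch import HLGaussLoss

    hl_gauss = HLGaussLoss(min_value = -10., max_value = 10., num_bins = 32)

    value_logits = torch.randn(2, 16, 32)            # (batch, time, num_reward_bins)
    returns = torch.empty(2, 16).uniform_(-10., 10.)

    loss = hl_gauss(value_logits, returns)           # loss against smeared two-hot targets
    values = hl_gauss(value_logits)                  # no target: expected scalar values, shape (2, 16)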
@@ -635,6 +740,25 @@ class Locoformer(Module):
         self.ppo_value_clip = ppo_value_clip
         self.value_loss_weight = value_loss_weight
 
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache (todo: find and cite this paper from ages ago)
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     @property
     def device(self):
         return next(self.parameters()).device
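`reward_shaping_fns` lets callers derive extra reward channels from whatever the wrapped environment step returns; each function may return a plain float or a tensor, and the new `state_to_rewards` method below stacks them. Two hypothetical shaping terms, purely illustrative:

    reward_shaping_fns = [
        lambda state: 1.,                              # constant alive bonus, plain floats are fine
        lambda state: -0.01 * state[0].square().sum()  # e.g. quadratic penalty on the observation part
    ]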
@@ -643,10 +767,10 @@ class Locoformer(Module):
         return self.unembedder.parameters()
 
     def critic_parameters(self):
-        if not exists(self.value_network):
+        if not exists(self.to_value_pred):
             return []
 
-        return self.value_network.parameters()
+        return self.to_value_pred.parameters()
 
     def ppo(
         self,
@@ -656,79 +780,150 @@
         reward,
         old_value,
         mask,
-        actor_optim,
-        critic_optim
+        episode_lens,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
     ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
+
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            (
+                state,
+                action,
+                old_action_log_prob,
+                reward,
+                old_value,
+                mask,
+                advantage,
+                returns
+            )
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
 
-        entropy = calc_entropy(action_logits)
+        # learn across windows
 
-        log_prob = action_logits.gather(-1, action)
-        log_prob = rearrange(log_prob, 'b t 1 -> b t')
+        cache = None
 
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns
+        ) in zip(*windowed_tensors):
 
+            (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+            entropy = calc_entropy(action_logits)
 
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
 
+            # update actor, classic clipped surrogate loss
 
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
 
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
 
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
 
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
 
-        clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
-        clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+            # update critic
 
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
 
+            value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
 
         # optimizer update
 
-        actor_optim.step()
-        actor_optim.zero_grad()
+        if exists(actor_optim):
+            actor_optim.step()
+            actor_optim.zero_grad()
 
-        critic_optim.step()
-        critic_optim.zero_grad()
+        if exists(critic_optim):
+            critic_optim.step()
+            critic_optim.zero_grad()
 
         # return losses for logging
 
         return mean_actor_loss.detach(), mean_critic_loss.detach()
 
+    def state_to_rewards(
+        self,
+        state
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+        return stack(rewards)
+
     def wrap_env_functions(self, env):
 
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el
+
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)
+
+            return tree_map(transform_output, env_reset_out)
 
         def wrapped_step(action, *args, **kwargs):
-            out = env.step(action.item(), *args, **kwargs)
 
-            return el
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
+
+            env_step_out_torch = tree_map(transform_output, env_step_out)
 
+            if not self.has_reward_shaping:
+                return env_step_out_torch
+
+            shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards
 
         return wrapped_reset, wrapped_step
 
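On the `use_spo` branch above: simple policy optimization replaces PPO's hard clip with a quadratic penalty on the ratio's deviation from 1, scaled by |A| / (2 * eps), so the surrogate stays smooth and differentiable everywhere. A toy comparison of the two objectives at a single point, mirroring the branch inside `ppo` with illustrative numbers:

    import torch

    ratio = torch.tensor(1.3)     # new over old action probability
    advantage = torch.tensor(2.0)
    eps_clip = 0.2

    ppo_obj = torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
    spo_obj = ratio * advantage - advantage.abs() * (ratio - 1.).square() / (2 * eps_clip)

    # ppo_obj -> 2.4 (hard clipped), spo_obj -> 2.15 (smoothly penalized)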
@@ -738,6 +933,7 @@ class Locoformer(Module):
         inference_mode = False,
         has_batch_dim = False,
         has_time_dim = False,
+        state_time_dim = 1,
         **kwargs
     ):
         window_size = self.window_size
@@ -753,23 +949,16 @@
             state = rearrange(state, '... -> 1 ...')
 
         if not has_time_dim:
-            state = state.unsqueeze(1)
+            state = state.unsqueeze(state_time_dim)
 
         # forwards
 
         out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
 
-        # handle cache
-
-        cache_len = cache.shape[-2]
-
-        if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-            cache = cache[..., -window_size:, :]
-
         # maybe remove batch or time
 
         if not has_time_dim:
-            out = tree_map_tensor(out, lambda t: t.squeeze(1))
+            out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))
 
         if not has_batch_dim:
             out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -798,16 +987,35 @@
     def forward(
         self,
         state: Tensor,
-        cache: Tensor | None = None,
+        cache: Cache | None = None,
         detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
     ):
         state = state.to(self.device)
 
         tokens = self.embedder(state)
 
-        embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
+
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)
 
         # unembed to actions - in language models this would be the next state
 
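The new assert above means callers must feed `forward` chunks that never straddle a window segment boundary. A quick worked check of the arithmetic, with hypothetical numbers:

    window_size = 8

    # the cache says we are at absolute timestep 6; 4 more tokens would reach
    # (6 % 8) + 4 = 10 > 8, crossing the boundary and tripping the assert
    assert not ((6 % window_size) + 4 <= window_size)

    # 2 tokens land exactly on the boundary, which is allowed
    assert (6 % window_size) + 2 <= window_size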
@@ -818,21 +1026,34 @@
         # maybe detach cache
 
         if detach_cache:
-            kv_cache = detach_all(kv_cache)
+            kv_cache = kv_cache.detach()
 
         # handle returning of values
 
         if return_values:
-            assert exists(self.value_network)
+            assert exists(self.to_value_pred)
 
-            values = self.value_network(embed)
+            values = self.to_value_pred(embed)
 
-            values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values
 
             out = (out, values)
 
         # output and cache
 
-        return out, kv_cache
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
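The recurrent cache trick at the end rolls the layer axis of the kv cache every time a window closes, so each layer attends to what the layer above cached one window earlier; stacked over depth, the receptive field extends well beyond a single XL window. A toy demonstration of the shift, assuming a (layers, batch, seq, dim) cache layout with the layer axis leading, as `dims = 0` implies:

    import torch

    # one distinct value per layer, for illustration
    kv_cache = torch.arange(4).view(4, 1, 1, 1).expand(4, 1, 8, 16).float()

    shifted = torch.roll(kv_cache, shifts = -1, dims = 0)

    assert (shifted[0] == kv_cache[1]).all()   # layer 0 now holds layer 1's memories
    assert (shifted[-1] == kv_cache[0]).all()  # the top layer wraps around to layer 0's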
{locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.15
+Version: 0.0.29
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,6 +38,7 @@ Requires-Dist: assoc-scan
 Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
@@ -54,7 +55,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
locoformer-0.0.29.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+locoformer/locoformer.py,sha256=Tr_1btuoTZ0huXeDcAeuHxTPaVeCUEGc5iLvMYGDLck,29982
+locoformer-0.0.29.dist-info/METADATA,sha256=5Fi3EOsgpBvpzAFVZQyrlink-HcHE8EgFl10Y5l8mqM,3256
+locoformer-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+locoformer-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+locoformer-0.0.29.dist-info/RECORD,,
locoformer-0.0.15.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
-locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
-locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
-locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-locoformer-0.0.15.dist-info/RECORD,,
{locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/WHEEL
File without changes

{locoformer-0.0.15.dist-info → locoformer-0.0.29.dist-info}/licenses/LICENSE
File without changes