locoformer 0.0.17__tar.gz → 0.0.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.17
+ Version: 0.0.37
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,8 +38,10 @@ Requires-Dist: assoc-scan
  Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hl-gauss-pytorch>=0.2.0
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: torch>=2.4
+ Requires-Dist: x-evolution
  Requires-Dist: x-mlps-pytorch
  Provides-Extra: examples
  Requires-Dist: accelerate; extra == 'examples'
@@ -54,7 +56,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -4,7 +4,7 @@

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors
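To make the changes below concrete, here is a minimal usage sketch of the new API surface, distilled from the test file added in this release (see the `@@ -0,0 +1,182 @@` hunk near the end of this diff); the constructor arguments and shapes simply mirror that test rather than any official example.

    import torch
    from torch import nn
    from x_mlps_pytorch import MLP
    from locoformer.locoformer import Locoformer

    # the value head is now specified with dim_value_input + reward_range
    # (HL-Gauss bins) instead of a scalar-output value_network
    model = Locoformer(
        embedder = nn.Embedding(256, 128),
        unembedder = nn.Linear(128, 256, bias = False),
        value_network = MLP(128, 64, 32),
        dim_value_input = 32,
        reward_range = (-100., 100.),
        transformer = dict(dim = 128, depth = 1, window_size = 512)
    )

    seq = torch.randint(0, 256, (3, 512))

    # the returned cache is now a (next_timestep, kv_cache) pair that is threaded
    # across window-sized segments for Transformer-XL style memory
    (logits, values), cache = model(seq, return_values = True)
    (logits, values), cache = model(seq, return_values = True, cache = cache)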
 
@@ -1,10 +1,14 @@
  from __future__ import annotations
+ from typing import Callable
+ from types import SimpleNamespace
  from functools import partial

  from pathlib import Path
  from contextlib import contextmanager
  from collections import namedtuple

+ from inspect import signature
+
  import numpy as np
  from numpy import ndarray
  from numpy.lib.format import open_memmap
@@ -16,7 +20,7 @@ import torch
  from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
- from torch.utils._pytree import tree_map
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  from torch.utils.data import Dataset, DataLoader
  from torch.optim import Optimizer

@@ -26,12 +30,20 @@ from einops.layers.torch import Rearrange

  from rotary_embedding_torch import RotaryEmbedding

+ from hl_gauss_pytorch import HLGaussLoss
+
  from assoc_scan import AssocScan

+ from x_mlps_pytorch import MLP
+
+ from x_evolution import EvoStrategy
+
  # constants

  LinearNoBias = partial(Linear, bias = False)

+ Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
  # helper functions

  def exists(v):
@@ -43,14 +55,24 @@ def default(v, d):
  def first(arr):
  return arr[0]

+ def xnor(x, y):
+ return not (x ^ y)
+
  def divisible_by(num, den):
  return (num % den) == 0

+ def get_param_names(fn):
+ parameters = signature(fn).parameters
+ return list(parameters.keys())
+
  # tensor helpers

  def log(t, eps = 1e-20):
  return t.clamp_min(eps).log()

+ def is_empty(t):
+ return t.numel() == 0
+
  def tree_map_tensor(x, fn):
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

@@ -67,10 +89,102 @@ def pad_at_dim(
  zeros = ((0, 0) * dims_from_right)
  return F.pad(t, (*zeros, *pad), value = value)

+ def normalize(t, eps = 1e-5):
+ return (t - t.mean()) / t.std().clamp_min(eps)
+
+ def tensor_to_dict(
+ t: Tensor,
+ config: tuple[tuple[str, int] | str],
+ dim = -1,
+ return_dottable = True
+ ):
+ config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+ names, sizes = zip(*config)
+ assert sum(sizes) == t.shape[dim]
+
+ t = t.split(sizes, dim = dim)
+ tensor_dict = dict(zip(names, t))
+
+ if not return_dottable:
+ return tensor_dict
+
+ return SimpleNamespace(**tensor_dict)
+
  def calc_entropy(logits):
  prob = logits.softmax(dim = -1)
  return -(prob * log(prob)).sum(dim = -1)

+ # reward functions - A.2
+
+ def reward_linear_velocity_command_tracking(
+ state,
+ command,
+ s1 = 1.
+ ):
+ if not (hasattr(state, 'v_xy') and hasattr(command, 'v_xy')):
+ return 0.
+
+ error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+ return torch.exp(-error / s1)
+
+ def reward_angular_velocity_command_tracking(
+ state,
+ command,
+ s2 = 1.
+ ):
+ if not (hasattr(state, 'w_z') and hasattr(command, 'w_z')):
+ return 0.
+
+ error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+ return torch.exp(-error / s2)
+
+ def reward_base_linear_velocity_penalty(
+ state
+ ):
+ if not hasattr(state, 'v_z'):
+ return 0.
+
+ return -state.v_z.norm(dim = -1).pow(2)
+
+ def reward_base_angular_velocity_penalty(
+ state
+ ):
+ if not hasattr(state, 'w_xy'):
+ return 0.
+
+ return -state.w_xy.norm(dim = -1).pow(2)
+
+ def reward_base_height_penalty(
+ state,
+ x_z_nominal = 0.27
+ ):
+ if not hasattr(state, 'x_z'):
+ return 0.
+
+ return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+ def reward_joint_acceleration_penalty(
+ state
+ ):
+ if not hasattr(state, 'joint_q'):
+ return 0.
+
+ return -state.joint_q.norm(dim = -1).pow(2)
+
+ def reward_torque_penalty(
+ state
+ ):
+ if not hasattr(state, 'tau'):
+ return 0.
+
+ return -state.tau.norm(dim = -1).pow(2)
+
+ def reward_alive(
+ state
+ ):
+ return 1.
+
  # generalized advantage estimate

  @torch.no_grad()
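The reward terms added above read named attributes off the state and command, so a flat observation vector is first given names with the new `tensor_to_dict` helper. A small sketch of that pairing follows; the 4-dimensional state layout here is hypothetical and only for illustration.

    import torch
    from locoformer.locoformer import tensor_to_dict, reward_linear_velocity_command_tracking

    # hypothetical layout: planar base velocity (2), yaw rate (1), base height (1)
    state   = tensor_to_dict(torch.randn(4), (('v_xy', 2), 'w_z', 'x_z'))
    command = tensor_to_dict(torch.tensor([0.5, 0.]), (('v_xy', 2),))

    # exp(-||v_xy - v_xy_command||^2 / s1), per reward_linear_velocity_command_tracking above
    reward = reward_linear_velocity_command_tracking(state, command)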
@@ -250,6 +364,57 @@ class ReplayDataset(Dataset):

  return data

+ class RemappedReplayDataset(Dataset):
+ def __init__(
+ self,
+ dataset: ReplayDataset,
+ episode_mapping: Tensor | list[list[int]],
+ shuffle_episodes = False
+ ):
+ assert len(dataset) > 0
+ self.dataset = dataset
+
+ if is_tensor(episode_mapping):
+ assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+ episode_mapping = episode_mapping.tolist()
+
+ self.episode_mapping = episode_mapping
+ self.shuffle_episodes = shuffle_episodes
+
+ def __len__(self):
+ return len(self.episode_mapping)
+
+ def __getitem__(self, idx):
+
+ episode_indices = self.episode_mapping[idx]
+
+ episode_indices = tensor(episode_indices)
+ episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+ assert not is_empty(episode_indices)
+
+ if self.shuffle_episodes and episode_indices.numel() > 1:
+ num_episodes = len(episode_indices)
+ episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+ episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+ episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+ keys = first(episode_data).keys()
+
+ values = [list(data.values()) for data in episode_data]
+
+ values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+ multi_episode_data = dict(zip(keys, values))
+
+ multi_episode_data['_lens'] = episode_lens.sum()
+
+ multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+ return multi_episode_data
+
  class ReplayBuffer:

  @beartype
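A sketch of how `RemappedReplayDataset` is intended to be driven, mirroring `test_replay` in the new test file further down: consecutive episodes are grouped so the XL cache can treat them as successive trials. `replay_buffer` is assumed to be a `ReplayBuffer` that has already stored a few episodes.

    from torch import arange, stack

    episode_indices = arange(len(replay_buffer))

    # pair each episode with the one after it -> Int[num_pairs, 2]
    episode_mapping = stack((episode_indices[:-1], episode_indices[1:]), dim = -1)

    # each batch element concatenates two episodes along time, with
    # '_episode_indices' marking which original episode each timestep came from
    dataloader = replay_buffer.dataloader(batch_size = 1, episode_mapping = episode_mapping)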
@@ -314,6 +479,9 @@ class ReplayBuffer:

  self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))

+ def __len__(self):
+ return (self.episode_lens > 0).sum().item()
+
  def reset_(self):
  self.episode_lens[:] = 0
  self.episode_index = 0
@@ -375,15 +543,92 @@ class ReplayBuffer:

  return self.memory_namedtuple(**data)

- def dataset(self) -> Dataset:
+ def dataset(
+ self,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ ) -> Dataset:
  self.flush()

- return ReplayDataset(self.folder)
+ dataset = ReplayDataset(self.folder)
+
+ if not exists(episode_mapping):
+ return dataset
+
+ return RemappedReplayDataset(dataset, episode_mapping)

- def dataloader(self, batch_size, **kwargs) -> DataLoader:
+ def dataloader(
+ self,
+ batch_size,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ **kwargs
+ ) -> DataLoader:
  self.flush()

- return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+ return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+ # normalization + conditioning (needed for the commands to the robot)
+
+ class MaybeAdaRMSNormWrapper(Module):
+ def __init__(
+ self,
+ fn: Module,
+ dim,
+ dim_cond = None
+ ):
+ super().__init__()
+ condition = exists(dim_cond)
+
+ self.fn = fn
+ self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+ self.accept_condition = condition
+
+ if condition:
+ self.to_gamma = LinearNoBias(dim_cond, dim)
+ self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+ nn.init.zeros_(self.to_gamma.weight)
+ nn.init.zeros_(self.to_ada_norm_zero.weight)
+ nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+ def forward(
+ self,
+ x,
+ cond = None,
+ **kwargs
+ ):
+
+ need_cond = self.accept_condition
+
+ assert xnor(exists(cond), need_cond)
+
+ prenormed = self.norm(x)
+
+ if need_cond:
+ if cond.ndim == 2:
+ cond = rearrange(cond, 'b d -> b 1 d')
+
+ scale_in = self.to_gamma(cond)
+ prenormed = prenormed * (scale_in + 1.)
+
+ all_fn_out = self.fn(prenormed, **kwargs)
+
+ if not need_cond:
+ return all_fn_out
+
+ # function may return multiple args
+
+ (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+ if need_cond:
+ scale_out = self.to_ada_norm_zero(cond).sigmoid()
+ out = out * scale_out
+
+ # restore
+
+ all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+ return all_fn_out

  # transformer-xl with ppo

@@ -394,15 +639,12 @@ class Attention(Module):
  window_size,
  dim_head = 64,
  heads = 8,
- pre_rmsnorm = True,
  fixed_window_size = False,
  accept_value_residual = False
  ):
  super().__init__()
  self.scale = dim_head ** -0.5

- self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
  self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -446,8 +688,6 @@

  device = tokens.device

- tokens = self.norm(tokens)
-
  q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

  q, k, v = map(self.split_heads, (q, k, v))
@@ -536,19 +776,26 @@ class TransformerXL(Module):
  dim_head = 64,
  heads = 8,
  expansion_factor = 4.,
+ dim_cond = None,
  final_norm = True,
  fixed_window_size = False,
  ):
  super().__init__()

+ condition = exists(dim_cond)
+
+ self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+ norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
  layers = ModuleList([])

  for i in range(depth):
  is_first = i == 0

- attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+ attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))

- ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+ ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

  layers.append(ModuleList([
  attn, ff
@@ -566,20 +813,32 @@ class TransformerXL(Module):
  self,
  x,
  cache = None,
- return_kv_cache = False
+ return_kv_cache = False,
+ condition: Tensor | None = None
  ):

+ # cache and residuals
+
  cache = default(cache, (None,) * len(self.layers))

  next_kv_caches = []
  value_residual = None

+ # handle condition
+
+ cond_tokens = None
+ if exists(condition):
+ assert exists(self.to_cond_tokens)
+ cond_tokens = self.to_cond_tokens(condition)
+
+ # layers
+
  for (attn, ff), kv_cache in zip(self.layers, cache):

- attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+ attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)

  x = attn_out + x
- x = ff(x) + x
+ x = ff(x, cond = cond_tokens) + x

  next_kv_caches.append(next_kv_cache)
  value_residual = default(value_residual, values)
@@ -603,14 +862,21 @@ class Locoformer(Module):
  embedder: Module,
  unembedder: Module,
  transformer: dict | TransformerXL,
- value_network: Module | None = None,
  discount_factor = 0.999,
  gae_lam = 0.95,
  ppo_eps_clip = 0.2,
  ppo_entropy_weight = 0.01,
  ppo_value_clip = 0.4,
+ dim_value_input = None, # needs to be set for value network to be available
+ value_network: Module = nn.Identity(),
+ reward_range: tuple[float, float] | None = None,
+ reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
+ num_reward_bins = 32,
+ hl_gauss_loss_kwargs = dict(),
  value_loss_weight = 0.5,
- calc_gae_kwargs: dict = dict()
+ calc_gae_kwargs: dict = dict(),
+ recurrent_kv_cache = True,
+ use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
  ):
  super().__init__()

@@ -622,11 +888,30 @@ class Locoformer(Module):
  self.embedder = embedder
  self.unembedder = unembedder

- self.value_network = value_network
-
  self.fixed_window_size = transformer.fixed_window_size
  self.window_size = transformer.window_size

+ # determine value network, using HL Gauss Layer
+
+ self.to_value_pred = None
+
+ if exists(dim_value_input):
+ assert exists(reward_range)
+
+ self.to_value_pred = nn.Sequential(
+ value_network,
+ LinearNoBias(dim_value_input, num_reward_bins)
+ )
+
+ reward_min, reward_max = reward_range
+
+ self.hl_gauss_loss = HLGaussLoss(
+ min_value = reward_min,
+ max_value = reward_max,
+ num_bins = num_reward_bins,
+ **hl_gauss_loss_kwargs
+ )
+
  # ppo related

  self.discount_factor = discount_factor
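A note on the critic rewrite above: the value head now emits `num_reward_bins` logits, and `HLGaussLoss` is used both to train them against scalar returns and to decode them back to scalar values. Both call patterns appear in `ppo` and `forward` below; this standalone sketch only isolates them, with illustrative shapes and the library's default sigma.

    import torch
    from hl_gauss_pytorch import HLGaussLoss

    hl_gauss = HLGaussLoss(min_value = -100., max_value = 100., num_bins = 32)

    value_logits = torch.randn(2, 8, 32)      # (batch, time, reward bins) from to_value_pred
    returns = torch.randn(2, 8) * 50.         # scalar targets within reward_range

    loss = hl_gauss(value_logits, returns)    # histogram-classification style value loss
    values = hl_gauss(value_logits)           # decode logits back to scalar values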
@@ -638,6 +923,19 @@ class Locoformer(Module):

  self.calc_gae_kwargs = calc_gae_kwargs

+ # maybe use spo
+
+ self.use_spo = use_spo
+
+ # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+ self.recurrent_kv_cache = recurrent_kv_cache
+
+ # reward shaping function
+
+ self.has_reward_shaping = exists(reward_shaping_fns)
+ self.reward_shaping_fns = reward_shaping_fns
+
  # loss related

  self.register_buffer('zero', tensor(0.), persistent = False)
@@ -650,10 +948,18 @@ class Locoformer(Module):
  return self.unembedder.parameters()

  def critic_parameters(self):
- if not exists(self.value_network):
+ if not exists(self.to_value_pred):
  return []

- return self.value_network.parameters()
+ return self.to_value_pred.parameters()
+
+ def evolve(
+ self,
+ environment,
+ **kwargs
+ ):
+ evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+ evo_strat()

  def ppo(
  self,
@@ -663,12 +969,20 @@ class Locoformer(Module):
  reward,
  old_value,
  mask,
+ episode_lens,
  actor_optim: Optimizer | None = None,
  critic_optim: Optimizer | None = None
  ):
  window_size = self.window_size
  total_learnable_tokens = mask.sum().item()

+ seq_len = state.shape[1]
+ gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+ advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+ advantage = normalize(advantage)
+
  windowed_tensors = [
  t.split(window_size, dim = 1) for t in
  (
@@ -677,7 +991,9 @@ class Locoformer(Module):
  old_action_log_prob,
  reward,
  old_value,
- mask
+ mask,
+ advantage,
+ returns
  )
  ]

@@ -694,10 +1010,12 @@ class Locoformer(Module):
  old_action_log_prob,
  reward,
  old_value,
- mask
+ mask,
+ advantage,
+ returns
  ) in zip(*windowed_tensors):

- (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
+ (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
  entropy = calc_entropy(action_logits)

  action = rearrange(action, 'b t -> b t 1')
@@ -709,9 +1027,10 @@ class Locoformer(Module):
  eps_clip = self.ppo_eps_clip
  ratio = (log_prob - old_action_log_prob).exp()

- advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
-
- actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+ if self.use_spo:
+ actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+ else:
+ actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

  actor_loss = actor_loss - self.ppo_entropy_weight * entropy

@@ -720,11 +1039,13 @@ class Locoformer(Module):

  # update critic

- value_loss = F.mse_loss(returns, value, reduction = 'none')
+ value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')

  value_clip = self.ppo_value_clip
+ value = self.hl_gauss_loss(value_logits)
+
  clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
- clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+ clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')

  critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight

@@ -750,28 +1071,65 @@ class Locoformer(Module):

  return mean_actor_loss.detach(), mean_critic_loss.detach()

+ def state_and_command_to_rewards(
+ self,
+ state,
+ commands = None
+ ) -> Tensor:
+
+ assert self.has_reward_shaping
+
+ rewards = []
+
+ for fn in self.reward_shaping_fns:
+ param_names = get_param_names(fn)
+ param_names = set(param_names) & {'state', 'command'}
+
+ if param_names == {'state'}: # only state
+ reward = fn(state = state)
+ elif param_names == {'state', 'command'}: # state and command
+ reward = fn(state = state, command = commands)
+ else:
+ raise ValueError('invalid number of arguments for reward shaping function')
+
+ rewards.append(reward)
+
+ # cast to Tensor if returns a float, just make it flexible for researcher
+
+ rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+
+ return stack(rewards)
+
  def wrap_env_functions(self, env):

- def wrapped_reset(*args, **kwargs):
- state, _ = env.reset(*args, **kwargs)
+ def transform_output(el):
+ if isinstance(el, ndarray):
+ return from_numpy(el)
+ elif isinstance(el, (int, bool, float)):
+ return tensor(el)
+ else:
+ return el

- if isinstance(state, ndarray):
- state = from_numpy(state)
+ def wrapped_reset(*args, **kwargs):
+ env_reset_out = env.reset(*args, **kwargs)

- return state, _
+ return tree_map(transform_output, env_reset_out)

  def wrapped_step(action, *args, **kwargs):
- out = env.step(action.item(), *args, **kwargs)

- def transform_output(el):
- if isinstance(el, ndarray):
- return from_numpy(el)
- elif isinstance(el, (int, bool, float)):
- return tensor(el)
- else:
- return el
+ if is_tensor(action):
+ action = action.item()
+
+ env_step_out = env.step(action, *args, **kwargs)
+
+ env_step_out_torch = tree_map(transform_output, env_step_out)
+
+ if not self.has_reward_shaping:
+ return env_step_out_torch

- return tree_map(transform_output, out)
+ shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
+
+ return env_step_out_torch, shaped_rewards

  return wrapped_reset, wrapped_step

@@ -781,13 +1139,18 @@ class Locoformer(Module):
  inference_mode = False,
  has_batch_dim = False,
  has_time_dim = False,
+ state_time_dim = 1,
  **kwargs
  ):
  window_size = self.window_size

  cache = None

- def stateful_forward(state: Tensor, **override_kwargs):
+ def stateful_forward(
+ state: Tensor,
+ condition: Tensor | None = None,
+ **override_kwargs
+ ):
  nonlocal cache

  # handle no batch or time, for easier time rolling out against envs
@@ -795,24 +1158,23 @@ class Locoformer(Module):
  if not has_batch_dim:
  state = rearrange(state, '... -> 1 ...')

- if not has_time_dim:
- state = rearrange(state, '... d -> ... 1 d')
-
- # forwards
+ if exists(command):
+ condition = rearrange(condition, '... -> 1 ...')

- out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+ if not has_time_dim:
+ state = state.unsqueeze(state_time_dim)

- # handle cache
+ if exists(command):
+ condition = rearrange(condition, '... d -> ... 1 d')

- cache_len = cache.shape[-2]
+ # forwards

- if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
- cache = cache[..., -window_size:, :]
+ out, cache = self.forward(state, condition = condition, cache = cache, **{**kwargs, **override_kwargs})

  # maybe remove batch or time

  if not has_time_dim:
- out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
+ out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

  if not has_batch_dim:
  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -841,16 +1203,36 @@ class Locoformer(Module):
  def forward(
  self,
  state: Tensor,
- cache: Tensor | None = None,
+ cache: Cache | None = None,
+ condition: Tensor | None = None,
  detach_cache = False,
- return_values = False
+ return_values = False,
+ return_raw_value_logits = False
  ):

  state = state.to(self.device)

  tokens = self.embedder(state)

- embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+ # time
+
+ time = tokens.shape[-2]
+
+ # destruct the cache for the current timestep and the cache
+
+ prev_kv_cache = None
+ timestep_start = 0
+
+ if exists(cache):
+ timestep_start, prev_kv_cache = cache
+
+ # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+ assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+ # attention
+
+ embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)

  # unembed to actions - in language models this would be the next state

@@ -866,16 +1248,29 @@
  # handle returning of values

  if return_values:
- assert exists(self.value_network)
+ assert exists(self.to_value_pred)

- values = self.value_network(embed)
+ values = self.to_value_pred(embed)

- if values.ndim == 3:
- assert values.shape[-1] == 1
- values = rearrange(values, '... 1 -> ...')
+ if not return_raw_value_logits:
+ values = self.hl_gauss_loss(values) # converts the value logits to scalar values

  out = (out, values)

  # output and cache

- return out, kv_cache
+ next_timestep = time + timestep_start
+
+ # handle curtailing kv cache at the right intervals
+
+ window_size = self.window_size
+
+ if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+ kv_cache = kv_cache[..., -window_size:, :]
+
+ # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+ if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+ kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+ return out, (next_timestep, kv_cache)
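Pulling the forward changes together: commands enter as `condition`, are projected by the transformer's `to_cond_tokens` MLP, and modulate every block through `MaybeAdaRMSNormWrapper`. A rollout-style sketch adapted from the parametrized test below; the command dimension of 2 is just the test's choice, not something prescribed by the paper.

    import torch
    from torch import nn
    from x_mlps_pytorch import MLP
    from locoformer.locoformer import Locoformer

    model = Locoformer(
        embedder = nn.Embedding(256, 128),
        unembedder = nn.Linear(128, 256, bias = False),
        value_network = MLP(128, 64, 32),
        dim_value_input = 32,
        reward_range = (-100., 100.),
        transformer = dict(dim = 128, depth = 1, window_size = 512, dim_cond = 2)
    )

    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)

    command = torch.randn(1, 1, 2)  # broadcast over the batch as the velocity command

    for state in torch.randint(0, 256, (3, 16)).unbind(dim = -1):
        state = state[:, None]  # (batch, 1)
        logits, values = stateful_forward(state, condition = command)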
@@ -1,6 +1,6 @@
  [project]
  name = "locoformer"
- version = "0.0.17"
+ version = "0.0.37"
  description = "LocoFormer"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,8 +30,10 @@ dependencies = [
  "beartype",
  "einx>=0.3.0",
  "einops>=0.8.0",
+ "hl-gauss-pytorch>=0.2.0",
  "rotary-embedding-torch",
  "torch>=2.4",
+ "x-evolution",
  "x-mlps-pytorch",
  ]

@@ -0,0 +1,182 @@
+ import pytest
+ param = pytest.mark.parametrize
+
+ import torch
+ from torch import nn
+ from x_mlps_pytorch import MLP
+
+ from einops import rearrange
+
+ from locoformer.locoformer import Locoformer
+
+ @param('recurrent_kv_cache', (False, True))
+ @param('has_commands', (False, True))
+ def test_locoformer(
+ recurrent_kv_cache,
+ has_commands
+ ):
+
+ model = Locoformer(
+ embedder = nn.Embedding(256, 128),
+ unembedder = nn.Linear(128, 256, bias = False),
+ value_network = MLP(128, 64, 32),
+ dim_value_input = 32,
+ reward_range = (-100., 100.),
+ recurrent_kv_cache = recurrent_kv_cache,
+ transformer = dict(
+ dim = 128,
+ depth = 1,
+ window_size = 512,
+ dim_cond = 2 if has_commands else None
+ )
+ )
+
+ seq = torch.randint(0, 256, (3, 512))
+
+ commands = None
+ if has_commands:
+ commands = torch.randn(3, 512, 2)
+
+ (logits, values), cache = model(seq, condition = commands, return_values = True)
+ (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)
+ (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)
+
+ assert logits.shape == (3, 512, 256)
+
+ stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
+
+ inference_command = torch.randn(1, 1, 2) if has_commands else None
+
+ for state in seq.unbind(dim = -1):
+ state = rearrange(state, 'b -> b 1')
+
+ logits, values = stateful_forward(state, condition = inference_command)
+ assert logits.shape == (3, 1, 256)
+
+ def test_replay():
+ from locoformer.locoformer import ReplayBuffer
+
+ replay_buffer = ReplayBuffer(
+ './replay_data',
+ max_episodes = 10_000,
+ max_timesteps = 501,
+ fields = dict(
+ state = ('float', (8,)),
+ action = 'int',
+ action_log_prob = 'float',
+ reward = 'float',
+ value = 'float',
+ done = 'bool'
+ )
+ )
+
+ lens = [3, 5, 4]
+
+ for episode_len in lens:
+ with replay_buffer.one_episode():
+ for _ in range(episode_len):
+ state = torch.randn((8,))
+ action = torch.randint(0, 4, ())
+ log_prob = torch.randn(())
+ reward = torch.randn(())
+ value = torch.randn(())
+ done = torch.randint(0, 2, ()).bool()
+
+ replay_buffer.store(
+ state = state,
+ action = action,
+ action_log_prob = log_prob,
+ reward = reward,
+ value = value,
+ done = done
+ )
+
+ dataset = replay_buffer.dataset()
+
+ assert len(dataset) == 3
+
+ assert torch.is_tensor(dataset[0]['state'])
+
+ dataloader = replay_buffer.dataloader(batch_size = 3)
+
+ assert next(iter(dataloader))['state'].shape[0] == 3
+
+ # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+ # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+ from torch import stack, arange
+
+ episode_indices = arange(len(replay_buffer))
+ remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+ dataloader = replay_buffer.dataloader(
+ batch_size = 1,
+ episode_mapping = remapped_episodes
+ )
+
+ assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+ def test_reward_shaping():
+
+ model = Locoformer(
+ embedder = nn.Embedding(256, 128),
+ unembedder = nn.Linear(128, 256, bias = False),
+ value_network = MLP(128, 64, 32),
+ dim_value_input = 32,
+ reward_range = (-100., 100.),
+ reward_shaping_fns = [
+ lambda state: (state[3] - 2.5).pow(2).mean(),
+ lambda state, command: state[4:6].norm(dim = -1)
+ ],
+ transformer = dict(
+ dim = 128,
+ depth = 1,
+ window_size = 512
+ )
+ )
+
+ import numpy as np
+
+ class MockEnv:
+ def reset(self):
+ return np.random.normal(size = (10,))
+
+ def step(self, *args, **kwargs):
+ return np.random.normal(size = (10,))
+
+
+ env = MockEnv()
+
+ reset_fn, step_fn = model.wrap_env_functions(env)
+
+ reset_fn()
+
+ _, rewards = step_fn(3)
+
+ assert len(rewards) == 2
+
+ def test_tensor_to_dict():
+ state = torch.randn(1, 3, 5)
+ config = (('xyz', 3), 'vx', 'vy')
+
+ from locoformer.locoformer import tensor_to_dict
+
+ state_dict = tensor_to_dict(state, config)
+ assert hasattr(state_dict, 'xyz') and state_dict.xyz.shape == (1, 3, 3)
+
+ def test_evo():
+
+ model = Locoformer(
+ embedder = nn.Embedding(256, 128),
+ unembedder = nn.Linear(128, 256, bias = False),
+ value_network = MLP(128, 64, 32),
+ dim_value_input = 32,
+ reward_range = (-100., 100.),
+ transformer = dict(
+ dim = 128,
+ depth = 1,
+ window_size = 512,
+ )
+ )
+
+ model.evolve(lambda model: 1., num_generations = 1)
@@ -160,7 +160,7 @@ for i in range(NUM_BATCHES):
  optim.step()
  optim.zero_grad()

- if divisible_by(i + 1, GENERATE_EVERY):
+ if divisible_by(i, GENERATE_EVERY):
  model.eval()

  val_seq = next(val_loader_iter)
@@ -25,7 +25,6 @@ import torch.nn.functional as F
  from torch.utils.data import TensorDataset, DataLoader
  from torch.optim import Adam

- import einx
  from einops import rearrange

  from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -60,8 +59,6 @@ def learn(
  batch_size = 16,
  epochs = 2,
  ):
- device = accelerator.device
-
  dl = replay.dataloader(batch_size = batch_size, shuffle = True)
  model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)

@@ -70,18 +67,14 @@ def learn(

  data = SimpleNamespace(**data)

- seq_len = data.state.shape[1]
-
- value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
- value = torch.where(value_mask, data.value, 0.)
-
  actor_loss, critic_loss = model.ppo(
  state = data.state,
  action = data.action,
  old_action_log_prob = data.action_log_prob,
  reward = data.reward,
- old_value = value,
+ old_value = data.value,
  mask = data.learnable,
+ episode_lens = data._lens,
  actor_optim = actor_optim,
  critic_optim = critic_optim
  )
@@ -94,7 +87,7 @@ def main(
  env_name = 'LunarLander-v3',
  num_episodes = 50_000,
  max_timesteps = 500,
- num_episodes_before_learn = 32,
+ num_episodes_before_learn = 64,
  clear_video = True,
  video_folder = 'recordings',
  record_every_episode = 250,
@@ -105,7 +98,8 @@ def main(
  ppo_eps_clip = 0.2,
  ppo_entropy_weight = .01,
  batch_size = 16,
- epochs = 2
+ epochs = 3,
+ reward_range = (-100., 100.)
  ):

  # accelerate
@@ -153,7 +147,6 @@ def main(
  locoformer = Locoformer(
  embedder = MLP(dim_state, 64, bias = False),
  unembedder = MLP(64, num_actions, bias = False),
- value_network = MLP(64, 1, bias = False),
  transformer = dict(
  dim = 64,
  dim_head = 32,
@@ -165,16 +158,20 @@ def main(
  gae_lam = gae_lam,
  ppo_eps_clip = ppo_eps_clip,
  ppo_entropy_weight = ppo_entropy_weight,
+ use_spo = True,
+ value_network = MLP(64, 64),
+ dim_value_input = 64,
+ reward_range = reward_range,
+ hl_gauss_loss_kwargs = dict(),
+ recurrent_kv_cache = True,
  calc_gae_kwargs = dict(
  use_accelerated = False
- )
+ ),
  ).to(device)

  optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
  optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)

- timesteps_learn = 0
-
  # able to wrap the env for all values to torch tensors and back
  # all environments should follow usual MDP interface, domain randomization should be given at instantiation

@@ -205,7 +202,8 @@ def main(

  # append to memory

- done = truncated or terminated
+ exceeds_max_timesteps = timestep == (max_timesteps - 1)
+ done = truncated or terminated or tensor(exceeds_max_timesteps)

  # get log prob of action

@@ -222,23 +220,24 @@ def main(
  learnable = tensor(True)
  )

- # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
- # only if terminated signal not detected
+ # increment counters

- if not terminated:
- _, next_value = stateful_forward(next_state, return_values = True)
+ timestep += 1

- memory._replace(value = next_value, learnable = False)
+ # break if done or exceed max timestep

- replay.store(**memory._asdict())
+ if done:

- # increment counters
+ # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+ # only if terminated signal not detected

- timestep += 1
+ if not terminated:
+ _, next_value = stateful_forward(next_state, return_values = True)

- # break if done or exceed max timestep
+ memory._replace(value = next_value, learnable = False)
+
+ replay.store(**memory._asdict())

- if done or timestep >= max_timesteps:
  break

  state = next_state
@@ -1,86 +0,0 @@
- import pytest
- param = pytest.mark.parametrize
-
- import torch
- from x_mlps_pytorch import MLP
-
- from einops import rearrange
-
- def test_locoformer():
- from locoformer.locoformer import Locoformer
- from torch import nn
-
- model = Locoformer(
- embedder = nn.Embedding(256, 128),
- unembedder = nn.Linear(128, 256, bias = False),
- value_network = MLP(128, 32, 1),
- transformer = dict(
- dim = 128,
- depth = 1,
- window_size = 512
- )
- )
-
- seq = torch.randint(0, 256, (3, 512))
-
- (logits, values), cache = model(seq, return_values = True)
- (logits, values), cache = model(seq, return_values = True, cache = cache)
- (logits, values), cache = model(seq, return_values = True, cache = cache)
-
- assert logits.shape == (3, 512, 256)
-
- stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
-
- for state in seq.unbind(dim = -1):
- state = rearrange(state, 'b -> b 1')
-
- logits, values = stateful_forward(state)
- assert logits.shape == (3, 1, 256)
-
- def test_replay():
- from locoformer.locoformer import ReplayBuffer
-
- replay_buffer = ReplayBuffer(
- './replay_data',
- max_episodes = 10_000,
- max_timesteps = 501,
- fields = dict(
- state = ('float', (8,)),
- action = 'int',
- action_log_prob = 'float',
- reward = 'float',
- value = 'float',
- done = 'bool'
- )
- )
-
- lens = [3, 5, 4]
-
- for episode_len in lens:
- with replay_buffer.one_episode():
- for _ in range(episode_len):
- state = torch.randn((8,))
- action = torch.randint(0, 4, ())
- log_prob = torch.randn(())
- reward = torch.randn(())
- value = torch.randn(())
- done = torch.randint(0, 2, ()).bool()
-
- replay_buffer.store(
- state = state,
- action = action,
- action_log_prob = log_prob,
- reward = reward,
- value = value,
- done = done
- )
-
- dataset = replay_buffer.dataset()
-
- assert len(dataset) == 3
-
- assert torch.is_tensor(dataset[0]['state'])
-
- dataloader = replay_buffer.dataloader(batch_size = 3)
-
- assert next(iter(dataloader))['state'].shape[0] == 3
5 files without changes