locoformer-0.0.17.tar.gz → locoformer-0.0.30.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.17
+ Version: 0.0.30
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,6 +38,7 @@ Requires-Dist: assoc-scan
  Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hl-gauss-pytorch>=0.2.0
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: torch>=2.4
  Requires-Dist: x-mlps-pytorch
@@ -54,7 +55,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -4,7 +4,7 @@

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

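The in-context adaptation described in the README depends on the Transformer-XL memories not being reset between trials. In this release the cache passed to `Locoformer.forward` becomes a `(curr_timestep, kv_cache)` tuple (the `Cache` namedtuple added further down in this diff). Below is a minimal sketch, not from the package, of carrying that cache across two back-to-back trials; the 10-dim observations, 4 actions, and the random tensors standing in for an environment are all made up for illustration.

```python
# sketch only: keep the XL cache alive across trials so the model can adapt in-context
import torch
from torch import nn
from locoformer.locoformer import Locoformer

model = Locoformer(
    embedder = nn.Linear(10, 64),    # assumed 10-dim observation
    unembedder = nn.Linear(64, 4),   # assumed 4 discrete actions
    transformer = dict(dim = 64, depth = 2, window_size = 32)
)

cache = None  # becomes a (curr_timestep, kv_cache) tuple - deliberately not reset per trial

for trial in range(2):
    state = torch.randn(10)          # stand-in for env.reset()

    for _ in range(64):
        # one timestep per call (explicit batch and time dims), so a single call
        # never crosses the window-segment boundary asserted in the new forward
        logits, cache = model(state[None, None, :], cache = cache)
        action = logits[0, -1].argmax().item()
        state = torch.randn(10)      # stand-in for env.step(action)
```
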
@@ -1,4 +1,5 @@
  from __future__ import annotations
+ from typing import Callable
  from functools import partial

  from pathlib import Path
@@ -16,7 +17,7 @@ import torch
  from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
- from torch.utils._pytree import tree_map
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  from torch.utils.data import Dataset, DataLoader
  from torch.optim import Optimizer

@@ -26,12 +27,16 @@ from einops.layers.torch import Rearrange

  from rotary_embedding_torch import RotaryEmbedding

+ from hl_gauss_pytorch import HLGaussLoss
+
  from assoc_scan import AssocScan

  # constants

  LinearNoBias = partial(Linear, bias = False)

+ Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
  # helper functions

  def exists(v):
@@ -43,6 +48,9 @@ def default(v, d):
  def first(arr):
  return arr[0]

+ def xnor(x, y):
+ return not (x ^ y)
+
  def divisible_by(num, den):
  return (num % den) == 0

@@ -51,6 +59,9 @@ def divisible_by(num, den):
  def log(t, eps = 1e-20):
  return t.clamp_min(eps).log()

+ def is_empty(t):
+ return t.numel() == 0
+
  def tree_map_tensor(x, fn):
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

@@ -67,6 +78,9 @@ def pad_at_dim(
  zeros = ((0, 0) * dims_from_right)
  return F.pad(t, (*zeros, *pad), value = value)

+ def normalize(t, eps = 1e-5):
+ return (t - t.mean()) / t.std().clamp_min(eps)
+
  def calc_entropy(logits):
  prob = logits.softmax(dim = -1)
  return -(prob * log(prob)).sum(dim = -1)
@@ -250,6 +264,57 @@ class ReplayDataset(Dataset):

  return data

+ class RemappedReplayDataset(Dataset):
+ def __init__(
+ self,
+ dataset: ReplayDataset,
+ episode_mapping: Tensor | list[list[int]],
+ shuffle_episodes = False
+ ):
+ assert len(dataset) > 0
+ self.dataset = dataset
+
+ if is_tensor(episode_mapping):
+ assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+ episode_mapping = episode_mapping.tolist()
+
+ self.episode_mapping = episode_mapping
+ self.shuffle_episodes = shuffle_episodes
+
+ def __len__(self):
+ return len(self.episode_mapping)
+
+ def __getitem__(self, idx):
+
+ episode_indices = self.episode_mapping[idx]
+
+ episode_indices = tensor(episode_indices)
+ episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+ assert not is_empty(episode_indices)
+
+ if self.shuffle_episodes and episode_indices.numel() > 1:
+ num_episodes = len(episode_indices)
+ episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+ episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+ episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+ keys = first(episode_data).keys()
+
+ values = [list(data.values()) for data in episode_data]
+
+ values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+ multi_episode_data = dict(zip(keys, values))
+
+ multi_episode_data['_lens'] = episode_lens.sum()
+
+ multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+ return multi_episode_data
+
  class ReplayBuffer:

  @beartype
@@ -314,6 +379,9 @@ class ReplayBuffer:

  self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))

+ def __len__(self):
+ return (self.episode_lens > 0).sum().item()
+
  def reset_(self):
  self.episode_lens[:] = 0
  self.episode_index = 0
@@ -375,15 +443,91 @@ class ReplayBuffer:

  return self.memory_namedtuple(**data)

- def dataset(self) -> Dataset:
+ def dataset(
+ self,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ ) -> Dataset:
  self.flush()

- return ReplayDataset(self.folder)
+ dataset = ReplayDataset(self.folder)
+
+ if not exists(episode_mapping):
+ return dataset

- def dataloader(self, batch_size, **kwargs) -> DataLoader:
+ return RemappedReplayDataset(dataset, episode_mapping)
+
+ def dataloader(
+ self,
+ batch_size,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ **kwargs
+ ) -> DataLoader:
  self.flush()

- return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+ return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+ # normalization + conditioning (needed for the commands to the robot)
+
+ class MaybeAdaRMSNormWrapper(Module):
+ def __init__(
+ self,
+ fn: Module,
+ dim,
+ dim_cond = None
+ ):
+ super().__init__()
+ condition = exists(dim_cond)
+
+ self.fn = fn
+ self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+ self.accept_condition = condition
+
+ if condition:
+ self.to_gamma = LinearNoBias(dim_cond, dim)
+ self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+ nn.init.zeros_(self.to_gamma.weight, 0.)
+ nn.init.zeros_(self.to_ada_norm_zero.weight, 0.)
+ nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+ def forward(
+ self,
+ x,
+ cond = None,
+ **kwargs
+ ):
+
+ need_cond = self.accept_condition
+ assert xnor(exists(cond), need_cond)
+
+ prenormed = self.norm(x)
+
+ if need_cond:
+ if cond.ndim == 2:
+ cond = rearrange(cond, 'b d -> b 1 d')
+
+ scale_in = self.to_gamma(cond)
+ prenormed = prenormed * (scale_in + 1.)
+
+ all_fn_out = self.fn(prenormed, **kwargs)
+
+ if not need_cond:
+ return all_fn_out
+
+ # function may return multiple args
+
+ (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+ if need_cond:
+ scale_out = self.to_ada_norm_zero(cond).sigmoid()
+ out = out * scale_out
+
+ # restore
+
+ all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+ return all_fn_out

  # transformer-xl with ppo

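The new `MaybeAdaRMSNormWrapper` above is how robot commands are meant to be injected: each block's pre-norm can be conditioned on a command embedding through a zero-initialised scale on the normed input plus a sigmoid output gate biased toward zero, so conditioning starts out as a near no-op. The attention and feedforward pre-norms are folded into this wrapper in the hunks that follow. Here is a stripped-down, self-contained sketch of that gating pattern; the class and argument names are illustrative, not the package's API.

```python
# illustrative sketch of the adaptive RMSNorm gating used by the wrapper above
import torch
from torch import nn

class AdaRMSNormBlock(nn.Module):
    def __init__(self, fn, dim, dim_cond):
        super().__init__()
        self.fn = fn
        self.norm = nn.RMSNorm(dim, elementwise_affine = False)  # torch >= 2.4
        self.to_gamma = nn.Linear(dim_cond, dim, bias = False)   # scale on the normed input
        self.to_gate = nn.Linear(dim_cond, dim)                  # ada-norm-zero style output gate

        nn.init.zeros_(self.to_gamma.weight)                     # gamma + 1 starts at identity
        nn.init.zeros_(self.to_gate.weight)
        nn.init.constant_(self.to_gate.bias, -5.)                # sigmoid(-5) ~ 0: branch starts nearly off

    def forward(self, x, cond):
        normed = self.norm(x) * (self.to_gamma(cond) + 1.)
        return self.fn(normed) * self.to_gate(cond).sigmoid()

block = AdaRMSNormBlock(nn.Linear(64, 64), dim = 64, dim_cond = 16)
out = block(torch.randn(2, 8, 64), cond = torch.randn(2, 1, 16))  # (batch, time, dim)
```
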
@@ -394,15 +538,12 @@ class Attention(Module):
  window_size,
  dim_head = 64,
  heads = 8,
- pre_rmsnorm = True,
  fixed_window_size = False,
  accept_value_residual = False
  ):
  super().__init__()
  self.scale = dim_head ** -0.5

- self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
  self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -446,8 +587,6 @@

  device = tokens.device

- tokens = self.norm(tokens)
-
  q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

  q, k, v = map(self.split_heads, (q, k, v))
@@ -536,19 +675,24 @@ class TransformerXL(Module):
  dim_head = 64,
  heads = 8,
  expansion_factor = 4.,
+ dim_cond = None,
  final_norm = True,
  fixed_window_size = False,
  ):
  super().__init__()

+ condition = exists(dim_cond)
+
+ norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim_cond)
+
  layers = ModuleList([])

  for i in range(depth):
  is_first = i == 0

- attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+ attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))

- ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+ ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

  layers.append(ModuleList([
  attn, ff
@@ -603,14 +747,21 @@ class Locoformer(Module):
  embedder: Module,
  unembedder: Module,
  transformer: dict | TransformerXL,
- value_network: Module | None = None,
  discount_factor = 0.999,
  gae_lam = 0.95,
  ppo_eps_clip = 0.2,
  ppo_entropy_weight = 0.01,
  ppo_value_clip = 0.4,
+ dim_value_input = None, # needs to be set for value network to be available
+ value_network: Module = nn.Identity(),
+ reward_range: tuple[float, float] | None = None,
+ reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+ num_reward_bins = 32,
+ hl_gauss_loss_kwargs = dict(),
  value_loss_weight = 0.5,
- calc_gae_kwargs: dict = dict()
+ calc_gae_kwargs: dict = dict(),
+ recurrent_kv_cache = True,
+ use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
  ):
  super().__init__()

@@ -622,11 +773,30 @@ class Locoformer(Module):
  self.embedder = embedder
  self.unembedder = unembedder

- self.value_network = value_network
-
  self.fixed_window_size = transformer.fixed_window_size
  self.window_size = transformer.window_size

+ # determine value network, using HL Gauss Layer
+
+ self.to_value_pred = None
+
+ if exists(dim_value_input):
+ assert exists(reward_range)
+
+ self.to_value_pred = nn.Sequential(
+ value_network,
+ LinearNoBias(dim_value_input, num_reward_bins)
+ )
+
+ reward_min, reward_max = reward_range
+
+ self.hl_gauss_loss = HLGaussLoss(
+ min_value = reward_min,
+ max_value = reward_max,
+ num_bins = num_reward_bins,
+ **hl_gauss_loss_kwargs
+ )
+
  # ppo related

  self.discount_factor = discount_factor
@@ -638,6 +808,19 @@ class Locoformer(Module):

  self.calc_gae_kwargs = calc_gae_kwargs

+ # maybe use spo
+
+ self.use_spo = use_spo
+
+ # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+ self.recurrent_kv_cache = recurrent_kv_cache
+
+ # reward shaping function
+
+ self.has_reward_shaping = exists(reward_shaping_fns)
+ self.reward_shaping_fns = reward_shaping_fns
+
  # loss related

  self.register_buffer('zero', tensor(0.), persistent = False)
@@ -650,10 +833,10 @@ class Locoformer(Module):
  return self.unembedder.parameters()

  def critic_parameters(self):
- if not exists(self.value_network):
+ if not exists(self.to_value_pred):
  return []

- return self.value_network.parameters()
+ return self.to_value_pred.parameters()

  def ppo(
  self,
@@ -663,12 +846,20 @@
  reward,
  old_value,
  mask,
+ episode_lens,
  actor_optim: Optimizer | None = None,
  critic_optim: Optimizer | None = None
  ):
  window_size = self.window_size
  total_learnable_tokens = mask.sum().item()

+ seq_len = state.shape[1]
+ gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+ advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+ advantage = normalize(advantage)
+
  windowed_tensors = [
  t.split(window_size, dim = 1) for t in
  (
@@ -677,7 +868,9 @@
  old_action_log_prob,
  reward,
  old_value,
- mask
+ mask,
+ advantage,
+ returns
  )
  ]

@@ -694,10 +887,12 @@
  old_action_log_prob,
  reward,
  old_value,
- mask
+ mask,
+ advantage,
+ returns
  ) in zip(*windowed_tensors):

- (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
+ (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
  entropy = calc_entropy(action_logits)

  action = rearrange(action, 'b t -> b t 1')
@@ -709,9 +904,10 @@
  eps_clip = self.ppo_eps_clip
  ratio = (log_prob - old_action_log_prob).exp()

- advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
-
- actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+ if self.use_spo:
+ actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+ else:
+ actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

  actor_loss = actor_loss - self.ppo_entropy_weight * entropy

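For reference, the two actor objectives the hunk above switches between, written out on dummy tensors (a sketch only; the expressions are copied from the code shown above): the PPO surrogate clips the probability ratio hard, while the SPO variant replaces the clip with a quadratic penalty on the ratio's deviation from 1, so the gradient does not vanish once the ratio leaves the trust region.

```python
# side-by-side sketch of the two actor losses selected by `use_spo` above
import torch

ratio = torch.tensor([0.7, 1.0, 1.4])       # exp(log_prob - old_action_log_prob)
advantage = torch.tensor([1.0, -2.0, 0.5])  # normalized GAE advantages
eps_clip = 0.2

# PPO clipped surrogate
ppo_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

# simple policy optimization (SPO)
spo_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
```
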
@@ -720,11 +916,13 @@

  # update critic

- value_loss = F.mse_loss(returns, value, reduction = 'none')
+ value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')

  value_clip = self.ppo_value_clip
+ value = self.hl_gauss_loss(value_logits)
+
  clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
- clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+ clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')

  critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight

@@ -750,28 +948,48 @@

  return mean_actor_loss.detach(), mean_critic_loss.detach()

+ def state_to_rewards(
+ self,
+ state
+ ) -> Tensor:
+
+ assert self.has_reward_shaping
+
+ rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+ rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+ return stack(rewards)
+
  def wrap_env_functions(self, env):

- def wrapped_reset(*args, **kwargs):
- state, _ = env.reset(*args, **kwargs)
+ def transform_output(el):
+ if isinstance(el, ndarray):
+ return from_numpy(el)
+ elif isinstance(el, (int, bool, float)):
+ return tensor(el)
+ else:
+ return el

- if isinstance(state, ndarray):
- state = from_numpy(state)
+ def wrapped_reset(*args, **kwargs):
+ env_reset_out = env.reset(*args, **kwargs)

- return state, _
+ return tree_map(transform_output, env_reset_out)

  def wrapped_step(action, *args, **kwargs):
- out = env.step(action.item(), *args, **kwargs)

- def transform_output(el):
- if isinstance(el, ndarray):
- return from_numpy(el)
- elif isinstance(el, (int, bool, float)):
- return tensor(el)
- else:
- return el
+ if is_tensor(action):
+ action = action.item()
+
+ env_step_out = env.step(action, *args, **kwargs)

- return tree_map(transform_output, out)
+ env_step_out_torch = tree_map(transform_output, env_step_out)
+
+ if not self.has_reward_shaping:
+ return env_step_out_torch
+
+ shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+ return env_step_out_torch, shaped_rewards

  return wrapped_reset, wrapped_step

@@ -781,6 +999,7 @@
  inference_mode = False,
  has_batch_dim = False,
  has_time_dim = False,
+ state_time_dim = 1,
  **kwargs
  ):
  window_size = self.window_size
@@ -796,23 +1015,16 @@
  state = rearrange(state, '... -> 1 ...')

  if not has_time_dim:
- state = rearrange(state, '... d -> ... 1 d')
+ state = state.unsqueeze(state_time_dim)

  # forwards

  out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})

- # handle cache
-
- cache_len = cache.shape[-2]
-
- if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
- cache = cache[..., -window_size:, :]
-
  # maybe remove batch or time

  if not has_time_dim:
- out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
+ out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

  if not has_batch_dim:
  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -841,16 +1053,35 @@
  def forward(
  self,
  state: Tensor,
- cache: Tensor | None = None,
+ cache: Cache | None = None,
  detach_cache = False,
- return_values = False
+ return_values = False,
+ return_raw_value_logits = False
  ):

  state = state.to(self.device)

  tokens = self.embedder(state)

- embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+ # time
+
+ time = tokens.shape[-2]
+
+ # destruct the cache for the current timestep and the cache
+
+ prev_kv_cache = None
+ timestep_start = 0
+
+ if exists(cache):
+ timestep_start, prev_kv_cache = cache
+
+ # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+ assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+ # attention
+
+ embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)

  # unembed to actions - in language models this would be the next state

@@ -866,16 +1097,29 @@
  # handle returning of values

  if return_values:
- assert exists(self.value_network)
+ assert exists(self.to_value_pred)

- values = self.value_network(embed)
+ values = self.to_value_pred(embed)

- if values.ndim == 3:
- assert values.shape[-1] == 1
- values = rearrange(values, '... 1 -> ...')
+ if not return_raw_value_logits:
+ values = self.hl_gauss_loss(values) # converts the value logits to scalar values

  out = (out, values)

  # output and cache

- return out, kv_cache
+ next_timestep = time + timestep_start
+
+ # handle curtailing kv cache at the right intervals
+
+ window_size = self.window_size
+
+ if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+ kv_cache = kv_cache[..., -window_size:, :]
+
+ # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+ if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+ kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+ return out, (next_timestep, kv_cache)
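
The value head change running through the hunks above: instead of regressing a scalar with `F.mse_loss`, `to_value_pred` now emits `num_reward_bins` logits, and `HLGaussLoss` from the new `hl-gauss-pytorch` dependency is used in both directions exactly as shown in the code - logits alone give the expected scalar value, logits plus a target give the classification loss. A small sketch of that pairing; the shapes and the (-100, 100) range are illustrative, taken from the tests in this diff.

```python
# sketch of the two HLGaussLoss call patterns used by the new value head
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -100., max_value = 100., num_bins = 32)

value_logits = torch.randn(2, 8, 32)        # (batch, time, num_reward_bins) from to_value_pred
returns = (torch.rand(2, 8) - 0.5) * 100.   # targets within the configured value range

values = hl_gauss(value_logits)                              # logits -> expected scalar values
loss = hl_gauss(value_logits, returns, reduction = 'none')   # per-timestep HL-Gauss loss
```
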
@@ -1,6 +1,6 @@
  [project]
  name = "locoformer"
- version = "0.0.17"
+ version = "0.0.30"
  description = "LocoFormer"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
  "beartype",
  "einx>=0.3.0",
  "einops>=0.8.0",
+ "hl-gauss-pytorch>=0.2.0",
  "rotary-embedding-torch",
  "torch>=2.4",
  "x-mlps-pytorch",
@@ -2,18 +2,25 @@ import pytest
  param = pytest.mark.parametrize

  import torch
+ from torch import nn
  from x_mlps_pytorch import MLP

  from einops import rearrange

- def test_locoformer():
- from locoformer.locoformer import Locoformer
- from torch import nn
+ from locoformer.locoformer import Locoformer
+
+ @param('recurrent_kv_cache', (False, True))
+ def test_locoformer(
+ recurrent_kv_cache
+ ):

  model = Locoformer(
  embedder = nn.Embedding(256, 128),
  unembedder = nn.Linear(128, 256, bias = False),
- value_network = MLP(128, 32, 1),
+ value_network = MLP(128, 64, 32),
+ dim_value_input = 32,
+ reward_range = (-100., 100.),
+ recurrent_kv_cache = recurrent_kv_cache,
  transformer = dict(
  dim = 128,
  depth = 1,
@@ -83,4 +90,58 @@ def test_replay():

  dataloader = replay_buffer.dataloader(batch_size = 3)

- assert next(iter(dataloader))['state'].shape[0] == 3
+ assert next(iter(dataloader))['state'].shape[0] == 3
+
+ # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+ # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+ from torch import stack, arange
+
+ episode_indices = arange(len(replay_buffer))
+ remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+ dataloader = replay_buffer.dataloader(
+ batch_size = 1,
+ episode_mapping = remapped_episodes
+ )
+
+ assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+ def test_reward_shaping():
+
+ model = Locoformer(
+ embedder = nn.Embedding(256, 128),
+ unembedder = nn.Linear(128, 256, bias = False),
+ value_network = MLP(128, 64, 32),
+ dim_value_input = 32,
+ reward_range = (-100., 100.),
+ reward_shaping_fns = [
+ lambda state: (state[3] - 2.5).pow(2).mean(),
+ lambda state: state[4:6].norm(dim = -1)
+ ],
+ transformer = dict(
+ dim = 128,
+ depth = 1,
+ window_size = 512
+ )
+ )
+
+ import numpy as np
+
+ class MockEnv:
+ def reset(self):
+ return np.random.normal(size = (10,))
+
+ def step(self, *args, **kwargs):
+ return np.random.normal(size = (10,))
+
+
+ env = MockEnv()
+
+ reset_fn, step_fn = model.wrap_env_functions(env)
+
+ reset_fn()
+
+ _, rewards = step_fn(3)
+
+ assert len(rewards) == 2
@@ -160,7 +160,7 @@ for i in range(NUM_BATCHES):
  optim.step()
  optim.zero_grad()

- if divisible_by(i + 1, GENERATE_EVERY):
+ if divisible_by(i, GENERATE_EVERY):
  model.eval()

  val_seq = next(val_loader_iter)
@@ -25,7 +25,6 @@ import torch.nn.functional as F
  from torch.utils.data import TensorDataset, DataLoader
  from torch.optim import Adam

- import einx
  from einops import rearrange

  from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -60,8 +59,6 @@ def learn(
  batch_size = 16,
  epochs = 2,
  ):
- device = accelerator.device
-
  dl = replay.dataloader(batch_size = batch_size, shuffle = True)
  model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)

@@ -70,18 +67,14 @@

  data = SimpleNamespace(**data)

- seq_len = data.state.shape[1]
-
- value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
- value = torch.where(value_mask, data.value, 0.)
-
  actor_loss, critic_loss = model.ppo(
  state = data.state,
  action = data.action,
  old_action_log_prob = data.action_log_prob,
  reward = data.reward,
- old_value = value,
+ old_value = data.value,
  mask = data.learnable,
+ episode_lens = data._lens,
  actor_optim = actor_optim,
  critic_optim = critic_optim
  )
@@ -94,7 +87,7 @@ def main(
  env_name = 'LunarLander-v3',
  num_episodes = 50_000,
  max_timesteps = 500,
- num_episodes_before_learn = 32,
+ num_episodes_before_learn = 64,
  clear_video = True,
  video_folder = 'recordings',
  record_every_episode = 250,
@@ -105,7 +98,8 @@
  ppo_eps_clip = 0.2,
  ppo_entropy_weight = .01,
  batch_size = 16,
- epochs = 2
+ epochs = 3,
+ reward_range = (-100., 100.)
  ):

  # accelerate
@@ -153,7 +147,6 @@
  locoformer = Locoformer(
  embedder = MLP(dim_state, 64, bias = False),
  unembedder = MLP(64, num_actions, bias = False),
- value_network = MLP(64, 1, bias = False),
  transformer = dict(
  dim = 64,
  dim_head = 32,
@@ -165,16 +158,20 @@
  gae_lam = gae_lam,
  ppo_eps_clip = ppo_eps_clip,
  ppo_entropy_weight = ppo_entropy_weight,
+ use_spo = True,
+ value_network = MLP(64, 64),
+ dim_value_input = 64,
+ reward_range = reward_range,
+ hl_gauss_loss_kwargs = dict(),
+ recurrent_kv_cache = True,
  calc_gae_kwargs = dict(
  use_accelerated = False
- )
+ ),
  ).to(device)

  optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
  optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)

- timesteps_learn = 0
-
  # able to wrap the env for all values to torch tensors and back
  # all environments should follow usual MDP interface, domain randomization should be given at instantiation

@@ -205,7 +202,8 @@

  # append to memory

- done = truncated or terminated
+ exceeds_max_timesteps = timestep == (max_timesteps - 1)
+ done = truncated or terminated or tensor(exceeds_max_timesteps)

  # get log prob of action

@@ -222,23 +220,24 @@
  learnable = tensor(True)
  )

- # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
- # only if terminated signal not detected
+ # increment counters

- if not terminated:
- _, next_value = stateful_forward(next_state, return_values = True)
+ timestep += 1

- memory._replace(value = next_value, learnable = False)
+ # break if done or exceed max timestep

- replay.store(**memory._asdict())
+ if done:

- # increment counters
+ # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+ # only if terminated signal not detected

- timestep += 1
+ if not terminated:
+ _, next_value = stateful_forward(next_state, return_values = True)

- # break if done or exceed max timestep
+ memory._replace(value = next_value, learnable = False)
+
+ replay.store(**memory._asdict())

- if done or timestep >= max_timesteps:
  break

  state = next_state
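
The reshuffled block above makes the episode bookkeeping explicit: the bootstrap step now runs only once `done` is set, and when the episode was truncated rather than terminated an extra non-learnable timestep carrying the next value is stored so GAE has something to bootstrap from. For reference, a generic sketch of how that trailing value enters the advantage recursion; this is not the package's `calc_gae`, just the standard formula.

```python
# generic GAE sketch: `values` has one extra trailing entry, which is exactly
# what the stored bootstrap timestep provides on truncation
import torch

def gae(rewards, values, gamma = 0.999, lam = 0.95):
    # rewards: (T,), values: (T + 1,) with values[-1] as the bootstrap value
    advantages = torch.zeros_like(rewards)
    running = 0.
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    returns = advantages + values[:-1]
    return advantages, returns

adv, ret = gae(torch.randn(8), torch.randn(9))
```
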
5 files without changes