locoformer 0.0.11__tar.gz → 0.0.30__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.11
+ Version: 0.0.30
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,6 +38,7 @@ Requires-Dist: assoc-scan
  Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hl-gauss-pytorch>=0.2.0
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: torch>=2.4
  Requires-Dist: x-mlps-pytorch
@@ -54,7 +55,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -4,7 +4,7 @@

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -1,8 +1,10 @@
  from __future__ import annotations
+ from typing import Callable
  from functools import partial

  from pathlib import Path
  from contextlib import contextmanager
+ from collections import namedtuple

  import numpy as np
  from numpy import ndarray
@@ -15,8 +17,9 @@ import torch
  from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
- from torch.utils._pytree import tree_map
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Optimizer

  import einx
  from einops import rearrange, einsum
@@ -24,10 +27,16 @@ from einops.layers.torch import Rearrange

  from rotary_embedding_torch import RotaryEmbedding

+ from hl_gauss_pytorch import HLGaussLoss
+
  from assoc_scan import AssocScan

+ # constants
+
  LinearNoBias = partial(Linear, bias = False)

+ Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
  # helper functions

  def exists(v):
@@ -39,15 +48,23 @@ def default(v, d):
  def first(arr):
      return arr[0]

+ def xnor(x, y):
+     return not (x ^ y)
+
  def divisible_by(num, den):
      return (num % den) == 0

+ # tensor helpers
+
+ def log(t, eps = 1e-20):
+     return t.clamp_min(eps).log()
+
+ def is_empty(t):
+     return t.numel() == 0
+
  def tree_map_tensor(x, fn):
      return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

- def detach_all(x):
-     return tree_map_tensor(x, lambda t: t.detach())
-
  def pad_at_dim(
      t,
      pad: tuple[int, int],
@@ -61,13 +78,20 @@ def pad_at_dim(
      zeros = ((0, 0) * dims_from_right)
      return F.pad(t, (*zeros, *pad), value = value)

+ def normalize(t, eps = 1e-5):
+     return (t - t.mean()) / t.std().clamp_min(eps)
+
+ def calc_entropy(logits):
+     prob = logits.softmax(dim = -1)
+     return -(prob * log(prob)).sum(dim = -1)
+
  # generalized advantage estimate

  @torch.no_grad()
  def calc_gae(
      rewards,
      values,
-     masks,
+     masks = None,
      gamma = 0.99,
      lam = 0.95,
      use_accelerated = None
@@ -78,6 +102,9 @@ def calc_gae(
      values = F.pad(values, (0, 1), value = 0.)
      values, values_next = values[..., :-1], values[..., 1:]

+     if not exists(masks):
+         masks = torch.ones_like(values)
+
      delta = rewards + gamma * values_next * masks - values
      gates = gamma * lam * masks

@@ -87,7 +114,7 @@ def calc_gae(

      returns = gae + values

-     return returns
+     return gae, returns
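
  With this change, callers receive both the advantages and the bootstrapped returns. For reference, a self-contained sketch of the same GAE(λ) recursion on toy numbers — a plain loop standing in for the package's associative-scan path, with illustrative names only:

```python
import torch

# toy single-episode rewards and value estimates
rewards = torch.tensor([1., 0., 0., 1.])
values  = torch.tensor([0.5, 0.4, 0.3, 0.2])
gamma, lam = 0.99, 0.95

values_next = torch.cat((values[1:], torch.zeros(1)))  # bootstrap value of 0 past the last step
delta = rewards + gamma * values_next - values         # TD residuals (mask of ones assumed)

gae = torch.zeros_like(delta)
running = torch.tensor(0.)
for t in reversed(range(len(delta))):                  # backwards scan, mirroring gates = gamma * lam
    running = delta[t] + gamma * lam * running
    gae[t] = running

returns = gae + values                                  # the second value now returned above
```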

  # transformer-xl mask w/ flex attn

@@ -129,8 +156,8 @@ def create_xl_mask(
          # handle intra-episodic attention if needed

          if exists(episode_ids):
-             q_episode = episodes[b, q + offset]
-             k_episode = episodes[b, k]
+             q_episode = episode_ids[b, q + offset]
+             k_episode = episode_ids[b, k]

              intra_episode_mask = q_episode == k_episode
              mask = mask & intra_episode_mask
@@ -231,12 +258,63 @@ class ReplayDataset(Dataset):

          episode_len = self.episode_lens[episode_index]

-         data = {field: torch.from_numpy(memmap[episode_index, :episode_len]) for field, memmap in self.memmaps.items()}
+         data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}

          data['_lens'] = tensor(episode_len)

          return data

+ class RemappedReplayDataset(Dataset):
+     def __init__(
+         self,
+         dataset: ReplayDataset,
+         episode_mapping: Tensor | list[list[int]],
+         shuffle_episodes = False
+     ):
+         assert len(dataset) > 0
+         self.dataset = dataset
+
+         if is_tensor(episode_mapping):
+             assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+             episode_mapping = episode_mapping.tolist()
+
+         self.episode_mapping = episode_mapping
+         self.shuffle_episodes = shuffle_episodes
+
+     def __len__(self):
+         return len(self.episode_mapping)
+
+     def __getitem__(self, idx):
+
+         episode_indices = self.episode_mapping[idx]
+
+         episode_indices = tensor(episode_indices)
+         episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+         assert not is_empty(episode_indices)
+
+         if self.shuffle_episodes and episode_indices.numel() > 1:
+             num_episodes = len(episode_indices)
+             episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+         episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+         episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+         keys = first(episode_data).keys()
+
+         values = [list(data.values()) for data in episode_data]
+
+         values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+         multi_episode_data = dict(zip(keys, values))
+
+         multi_episode_data['_lens'] = episode_lens.sum()
+
+         multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+         return multi_episode_data
+
  class ReplayBuffer:

      @beartype
@@ -299,6 +377,16 @@ class ReplayBuffer:
              self.shapes[field_name] = shape
              self.dtypes[field_name] = dtype

+         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+     def __len__(self):
+         return (self.episode_lens > 0).sum().item()
+
+     def reset_(self):
+         self.episode_lens[:] = 0
+         self.episode_index = 0
+         self.timestep_index = 0
+
      def advance_episode(self):
          self.episode_index = (self.episode_index + 1) % self.max_episodes
          self.timestep_index = 0
@@ -353,15 +441,93 @@ class ReplayBuffer:

          self.timestep_index += 1

-     def dataset(self) -> Dataset:
+         return self.memory_namedtuple(**data)
+
+     def dataset(
+         self,
+         episode_mapping: Tensor | list[list[int]] | None = None,
+     ) -> Dataset:
          self.flush()

-         return ReplayDataset(self.folder)
+         dataset = ReplayDataset(self.folder)
+
+         if not exists(episode_mapping):
+             return dataset

-     def dataloader(self, **kwargs) -> DataLoader:
+         return RemappedReplayDataset(dataset, episode_mapping)
+
+     def dataloader(
+         self,
+         batch_size,
+         episode_mapping: Tensor | list[list[int]] | None = None,
+         **kwargs
+     ) -> DataLoader:
          self.flush()

-         return DataLoader(self.dataset(), collate_fn = collate_var_time, **kwargs)
+         return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+ # normalization + conditioning (needed for the commands to the robot)
+
+ class MaybeAdaRMSNormWrapper(Module):
+     def __init__(
+         self,
+         fn: Module,
+         dim,
+         dim_cond = None
+     ):
+         super().__init__()
+         condition = exists(dim_cond)
+
+         self.fn = fn
+         self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+         self.accept_condition = condition
+
+         if condition:
+             self.to_gamma = LinearNoBias(dim_cond, dim)
+             self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+             nn.init.zeros_(self.to_gamma.weight, 0.)
+             nn.init.zeros_(self.to_ada_norm_zero.weight, 0.)
+             nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+     def forward(
+         self,
+         x,
+         cond = None,
+         **kwargs
+     ):
+
+         need_cond = self.accept_condition
+         assert xnor(exists(cond), need_cond)
+
+         prenormed = self.norm(x)
+
+         if need_cond:
+             if cond.ndim == 2:
+                 cond = rearrange(cond, 'b d -> b 1 d')
+
+             scale_in = self.to_gamma(cond)
+             prenormed = prenormed * (scale_in + 1.)
+
+         all_fn_out = self.fn(prenormed, **kwargs)
+
+         if not need_cond:
+             return all_fn_out
+
+         # function may return multiple args
+
+         (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+         if need_cond:
+             scale_out = self.to_ada_norm_zero(cond).sigmoid()
+             out = out * scale_out
+
+         # restore
+
+         all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+         return all_fn_out
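
  The wrapper above is what lets the robot's command vector steer every attention and feedforward block. A standalone re-illustration of the same gating pattern (not the package class; names are illustrative): the condition scales the RMS-normalized input on the way in, and a sigmoid gate initialized mostly closed (bias -5) scales the block output on the way out.

```python
import torch
from torch import nn

class AdaRMSNormGate(nn.Module):
    # illustrative module mirroring the scale-in / gated scale-out flow above
    def __init__(self, fn, dim, dim_cond):
        super().__init__()
        self.fn = fn
        self.norm = nn.RMSNorm(dim, elementwise_affine = False)
        self.to_gamma = nn.Linear(dim_cond, dim, bias = False)
        self.to_gate = nn.Linear(dim_cond, dim)

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.zeros_(self.to_gate.weight)
        nn.init.constant_(self.to_gate.bias, -5.)  # gate starts near zero, block fades in during training

    def forward(self, x, cond):
        if cond.ndim == 2:
            cond = cond.unsqueeze(1)                               # broadcast the condition over time

        out = self.fn(self.norm(x) * (self.to_gamma(cond) + 1.))   # scale-in
        return out * self.to_gate(cond).sigmoid()                  # gated scale-out

block = AdaRMSNormGate(nn.Linear(64, 64), dim = 64, dim_cond = 16)
x, cond = torch.randn(2, 8, 64), torch.randn(2, 16)
assert block(x, cond).shape == (2, 8, 64)
```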

  # transformer-xl with ppo

@@ -372,15 +538,12 @@ class Attention(Module):
          window_size,
          dim_head = 64,
          heads = 8,
-         pre_rmsnorm = True,
          fixed_window_size = False,
          accept_value_residual = False
      ):
          super().__init__()
          self.scale = dim_head ** -0.5

-         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
          self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
          self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -421,12 +584,9 @@ class Attention(Module):
          return_kv_cache = False,
      ):
          seq_len = tokens.shape[-2]
-         assert seq_len <= self.window_size

          device = tokens.device

-         tokens = self.norm(tokens)
-
          q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

          q, k, v = map(self.split_heads, (q, k, v))
@@ -515,19 +675,24 @@ class TransformerXL(Module):
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
+         dim_cond = None,
          final_norm = True,
          fixed_window_size = False,
      ):
          super().__init__()

+         condition = exists(dim_cond)
+
+         norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim_cond)
+
          layers = ModuleList([])

          for i in range(depth):
              is_first = i == 0

-             attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+             attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))

-             ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+             ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

              layers.append(ModuleList([
                  attn, ff
@@ -582,7 +747,21 @@ class Locoformer(Module):
          embedder: Module,
          unembedder: Module,
          transformer: dict | TransformerXL,
-         value_network: Module | None = None
+         discount_factor = 0.999,
+         gae_lam = 0.95,
+         ppo_eps_clip = 0.2,
+         ppo_entropy_weight = 0.01,
+         ppo_value_clip = 0.4,
+         dim_value_input = None, # needs to be set for value network to be available
+         value_network: Module = nn.Identity(),
+         reward_range: tuple[float, float] | None = None,
+         reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+         num_reward_bins = 32,
+         hl_gauss_loss_kwargs = dict(),
+         value_loss_weight = 0.5,
+         calc_gae_kwargs: dict = dict(),
+         recurrent_kv_cache = True,
+         use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
      ):
          super().__init__()

@@ -594,11 +773,58 @@ class Locoformer(Module):
          self.embedder = embedder
          self.unembedder = unembedder

-         self.value_network = value_network
-
          self.fixed_window_size = transformer.fixed_window_size
          self.window_size = transformer.window_size

+         # determine value network, using HL Gauss Layer
+
+         self.to_value_pred = None
+
+         if exists(dim_value_input):
+             assert exists(reward_range)
+
+             self.to_value_pred = nn.Sequential(
+                 value_network,
+                 LinearNoBias(dim_value_input, num_reward_bins)
+             )
+
+             reward_min, reward_max = reward_range
+
+             self.hl_gauss_loss = HLGaussLoss(
+                 min_value = reward_min,
+                 max_value = reward_max,
+                 num_bins = num_reward_bins,
+                 **hl_gauss_loss_kwargs
+             )
+
+         # ppo related
+
+         self.discount_factor = discount_factor
+         self.gae_lam = gae_lam
+         self.ppo_eps_clip = ppo_eps_clip
+         self.ppo_entropy_weight = ppo_entropy_weight
+         self.ppo_value_clip = ppo_value_clip
+         self.value_loss_weight = value_loss_weight
+
+         self.calc_gae_kwargs = calc_gae_kwargs
+
+         # maybe use spo
+
+         self.use_spo = use_spo
+
+         # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+         self.recurrent_kv_cache = recurrent_kv_cache
+
+         # reward shaping function
+
+         self.has_reward_shaping = exists(reward_shaping_fns)
+         self.reward_shaping_fns = reward_shaping_fns
+
+         # loss related
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
      @property
      def device(self):
          return next(self.parameters()).device
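
  The value head now emits `num_reward_bins` logits rather than a scalar, and the same `HLGaussLoss` module both scores those logits against scalar returns and decodes them back to expected values, mirroring the two call signatures used in the `ppo` method below. A minimal sketch of that round trip with made-up shapes (assumes the default sigma of `hl-gauss-pytorch` is acceptable):

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -100., max_value = 100., num_bins = 32)

value_logits = torch.randn(2, 16, 32)       # (batch, time, num_reward_bins) from the value head
returns = (torch.rand(2, 16) - 0.5) * 200.  # scalar regression targets inside reward_range

loss = hl_gauss(value_logits, returns, reduction = 'none')  # per-timestep classification-style loss
values = hl_gauss(value_logits)                             # decode logits back to scalar values
assert values.shape == (2, 16)
```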
@@ -607,33 +833,163 @@ class Locoformer(Module):
          return self.unembedder.parameters()

      def critic_parameters(self):
-         if not exists(self.value_network):
+         if not exists(self.to_value_pred):
              return []

-         return self.value_network.parameters()
+         return self.to_value_pred.parameters()
+
+     def ppo(
+         self,
+         state,
+         action,
+         old_action_log_prob,
+         reward,
+         old_value,
+         mask,
+         episode_lens,
+         actor_optim: Optimizer | None = None,
+         critic_optim: Optimizer | None = None
+     ):
+         window_size = self.window_size
+         total_learnable_tokens = mask.sum().item()
+
+         seq_len = state.shape[1]
+         gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+         advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+         advantage = normalize(advantage)
+
+         windowed_tensors = [
+             t.split(window_size, dim = 1) for t in
+             (
+                 state,
+                 action,
+                 old_action_log_prob,
+                 reward,
+                 old_value,
+                 mask,
+                 advantage,
+                 returns
+             )
+         ]
+
+         mean_actor_loss = self.zero.clone()
+         mean_critic_loss = self.zero.clone()
+
+         # learn across windows
+
+         cache = None
+
+         for (
+             state,
+             action,
+             old_action_log_prob,
+             reward,
+             old_value,
+             mask,
+             advantage,
+             returns
+         ) in zip(*windowed_tensors):
+
+             (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+             entropy = calc_entropy(action_logits)
+
+             action = rearrange(action, 'b t -> b t 1')
+             log_prob = action_logits.gather(-1, action)
+             log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+             # update actor, classic clipped surrogate loss
+
+             eps_clip = self.ppo_eps_clip
+             ratio = (log_prob - old_action_log_prob).exp()
+
+             if self.use_spo:
+                 actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+             else:
+                 actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+             actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+             windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+             windowed_actor_loss.backward(retain_graph = True)
+
+             # update critic
+
+             value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
+
+             value_clip = self.ppo_value_clip
+             value = self.hl_gauss_loss(value_logits)
+
+             clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+             clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+             critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+             windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+             windowed_critic_loss.backward(retain_graph = True)
+
+             # accumulate
+
+             mean_actor_loss.add_(windowed_actor_loss)
+             mean_critic_loss.add_(windowed_critic_loss)
+
+         # optimizer update
+
+         if exists(actor_optim):
+             actor_optim.step()
+             actor_optim.zero_grad()
+
+         if exists(critic_optim):
+             critic_optim.step()
+             critic_optim.zero_grad()
+
+         # return losses for logging
+
+         return mean_actor_loss.detach(), mean_critic_loss.detach()
+
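  The `use_spo` branch above swaps PPO's clipped surrogate for the quadratic-penalty objective of simple policy optimization; the two per-token losses can be compared directly on toy numbers. An illustrative restatement of the formulas in the loop above, not additional package code:

```python
import torch

ratio = torch.tensor([0.7, 1.0, 1.4])       # new / old action probability
advantage = torch.tensor([1.0, -0.5, 2.0])
eps_clip = 0.2

# classic PPO clipped surrogate (the `else` branch)
ppo_loss = -torch.min(
    ratio * advantage,
    ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage
)

# simple policy optimization: no hard clip, instead a quadratic penalty on the
# ratio's deviation from 1, scaled by |advantage| / (2 * eps_clip)
spo_loss = -(ratio * advantage - advantage.abs() * (ratio - 1.).square() / (2 * eps_clip))
```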
+     def state_to_rewards(
+         self,
+         state
+     ) -> Tensor:
+
+         assert self.has_reward_shaping
+
+         rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+         rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+         return stack(rewards)

      def wrap_env_functions(self, env):

-         def wrapped_reset(*args, **kwargs):
-             state, _ = env.reset(*args, **kwargs)
+         def transform_output(el):
+             if isinstance(el, ndarray):
+                 return from_numpy(el)
+             elif isinstance(el, (int, bool, float)):
+                 return tensor(el)
+             else:
+                 return el

-             if isinstance(state, ndarray):
-                 state = from_numpy(state)
+         def wrapped_reset(*args, **kwargs):
+             env_reset_out = env.reset(*args, **kwargs)

-             return state, _
+             return tree_map(transform_output, env_reset_out)

          def wrapped_step(action, *args, **kwargs):
-             out = env.step(action.item(), *args, **kwargs)

-             def transform_output(el):
-                 if isinstance(el, ndarray):
-                     return from_numpy(el)
-                 elif isinstance(el, (int, bool, float)):
-                     return tensor(el)
-                 else:
-                     return el
+             if is_tensor(action):
+                 action = action.item()
+
+             env_step_out = env.step(action, *args, **kwargs)
+
+             env_step_out_torch = tree_map(transform_output, env_step_out)

-             return tree_map(transform_output, out)
+             if not self.has_reward_shaping:
+                 return env_step_out_torch
+
+             shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+             return env_step_out_torch, shaped_rewards

          return wrapped_reset, wrapped_step

@@ -643,6 +999,7 @@ class Locoformer(Module):
          inference_mode = False,
          has_batch_dim = False,
          has_time_dim = False,
+         state_time_dim = 1,
          **kwargs
      ):
          window_size = self.window_size
@@ -658,23 +1015,16 @@ class Locoformer(Module):
                  state = rearrange(state, '... -> 1 ...')

              if not has_time_dim:
-                 state = rearrange(state, '... d -> ... 1 d')
+                 state = state.unsqueeze(state_time_dim)

              # forwards

              out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})

-             # handle cache
-
-             cache_len = cache.shape[-2]
-
-             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-                 cache = cache[..., -window_size:, :]
-
              # maybe remove batch or time

              if not has_time_dim:
-                 out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
+                 out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

              if not has_batch_dim:
                  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -703,16 +1053,35 @@ class Locoformer(Module):
      def forward(
          self,
          state: Tensor,
-         cache: Tensor | None = None,
+         cache: Cache | None = None,
          detach_cache = False,
-         return_values = False
+         return_values = False,
+         return_raw_value_logits = False
      ):

          state = state.to(self.device)

          tokens = self.embedder(state)

-         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+         # time
+
+         time = tokens.shape[-2]
+
+         # destruct the cache for the current timestep and the cache
+
+         prev_kv_cache = None
+         timestep_start = 0
+
+         if exists(cache):
+             timestep_start, prev_kv_cache = cache
+
+         # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+         assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+         # attention
+
+         embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)

          # unembed to actions - in language models this would be the next state

@@ -723,21 +1092,34 @@ class Locoformer(Module):
          # maybe detach cache

          if detach_cache:
-             kv_cache = detach_all(kv_cache)
+             kv_cache = kv_cache.detach()

          # handle returning of values

          if return_values:
-             assert exists(self.value_network)
+             assert exists(self.to_value_pred)

-             values = self.value_network(embed)
+             values = self.to_value_pred(embed)

-             if values.ndim == 3:
-                 assert values.shape[-1] == 1
-                 values = rearrange(values, '... 1 -> ...')
+             if not return_raw_value_logits:
+                 values = self.hl_gauss_loss(values) # converts the value logits to scalar values

              out = (out, values)

          # output and cache

-         return out, kv_cache
+         next_timestep = time + timestep_start
+
+         # handle curtailing kv cache at the right intervals
+
+         window_size = self.window_size
+
+         if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+             kv_cache = kv_cache[..., -window_size:, :]
+
+         # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+         if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+             kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+         return out, (next_timestep, kv_cache)
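
  With the new return signature, a caller threads the `(timestep, kv_cache)` pair back in on every call, and the assert above requires that no single call straddles a window boundary. A usage sketch built from the same constructor arguments as the tests further down, assuming the package behaves as shown in this diff (illustrative values only):

```python
import torch
from torch import nn
from x_mlps_pytorch import MLP
from locoformer.locoformer import Locoformer

model = Locoformer(
    embedder = nn.Embedding(256, 128),
    unembedder = nn.Linear(128, 256, bias = False),
    value_network = MLP(128, 64, 32),
    dim_value_input = 32,
    reward_range = (-100., 100.),
    transformer = dict(dim = 128, depth = 1, window_size = 512)
)

tokens = torch.randint(0, 256, (1, 1024))
cache = None

# chunk the sequence so no forward call crosses a window boundary
for chunk in tokens.split(model.window_size, dim = 1):
    (action_logits, values), cache = model(chunk, cache = cache, return_values = True)
    # cache == (next_timestep, kv_cache); the kv cache is curtailed every other
    # window and, with recurrent_kv_cache, rolled one layer down each window
```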
@@ -1,6 +1,6 @@
  [project]
  name = "locoformer"
- version = "0.0.11"
+ version = "0.0.30"
  description = "LocoFormer"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
      "beartype",
      "einx>=0.3.0",
      "einops>=0.8.0",
+     "hl-gauss-pytorch>=0.2.0",
      "rotary-embedding-torch",
      "torch>=2.4",
      "x-mlps-pytorch",
@@ -2,18 +2,25 @@ import pytest
  param = pytest.mark.parametrize

  import torch
+ from torch import nn
  from x_mlps_pytorch import MLP

  from einops import rearrange

- def test_locoformer():
-     from locoformer.locoformer import Locoformer
-     from torch import nn
+ from locoformer.locoformer import Locoformer
+
+ @param('recurrent_kv_cache', (False, True))
+ def test_locoformer(
+     recurrent_kv_cache
+ ):

      model = Locoformer(
          embedder = nn.Embedding(256, 128),
          unembedder = nn.Linear(128, 256, bias = False),
-         value_network = MLP(128, 32, 1),
+         value_network = MLP(128, 64, 32),
+         dim_value_input = 32,
+         reward_range = (-100., 100.),
+         recurrent_kv_cache = recurrent_kv_cache,
          transformer = dict(
              dim = 128,
              depth = 1,
@@ -83,4 +90,58 @@ def test_replay():

      dataloader = replay_buffer.dataloader(batch_size = 3)

-     assert next(iter(dataloader))['state'].shape[0] == 3
+     assert next(iter(dataloader))['state'].shape[0] == 3
+
+     # we will now consider consecutive pairs of episodes as 2 trials to be used for in-context adaptation
+     # but realistically there will be a function that converts a given ReplayBuffer -> Int[batch, episode_indices]
+
+     from torch import stack, arange
+
+     episode_indices = arange(len(replay_buffer))
+     remapped_episodes = stack((episode_indices[:-1], episode_indices[1:]))
+
+     dataloader = replay_buffer.dataloader(
+         batch_size = 1,
+         episode_mapping = remapped_episodes
+     )
+
+     assert next(iter(dataloader))['_lens'][0] == (3 + 5) # first and second episodes are concatted together timewise
+
+ def test_reward_shaping():
+
+     model = Locoformer(
+         embedder = nn.Embedding(256, 128),
+         unembedder = nn.Linear(128, 256, bias = False),
+         value_network = MLP(128, 64, 32),
+         dim_value_input = 32,
+         reward_range = (-100., 100.),
+         reward_shaping_fns = [
+             lambda state: (state[3] - 2.5).pow(2).mean(),
+             lambda state: state[4:6].norm(dim = -1)
+         ],
+         transformer = dict(
+             dim = 128,
+             depth = 1,
+             window_size = 512
+         )
+     )
+
+     import numpy as np
+
+     class MockEnv:
+         def reset(self):
+             return np.random.normal(size = (10,))
+
+         def step(self, *args, **kwargs):
+             return np.random.normal(size = (10,))
+
+
+     env = MockEnv()
+
+     reset_fn, step_fn = model.wrap_env_functions(env)
+
+     reset_fn()
+
+     _, rewards = step_fn(3)
+
+     assert len(rewards) == 2
@@ -160,7 +160,7 @@ for i in range(NUM_BATCHES):
      optim.step()
      optim.zero_grad()

-     if divisible_by(i + 1, GENERATE_EVERY):
+     if divisible_by(i, GENERATE_EVERY):
          model.eval()

          val_seq = next(val_loader_iter)
@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
          prime = prime.to(model.device)
          out = prime

-         stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+         stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)

          # sample

@@ -3,7 +3,7 @@
  # "accelerate",
  # "fire",
  # "gymnasium[box2d]>=1.0.0",
- # "locoformer",
+ # "locoformer>=0.0.12",
  # "moviepy",
  # "tqdm"
  # ]
@@ -13,13 +13,14 @@ from fire import Fire
  from shutil import rmtree
  from tqdm import tqdm
  from collections import deque
+ from types import SimpleNamespace

  from accelerate import Accelerator

  import gymnasium as gym

  import torch
- from torch import from_numpy, randint, tensor, stack
+ from torch import from_numpy, randint, tensor, stack, arange
  import torch.nn.functional as F
  from torch.utils.data import TensorDataset, DataLoader
  from torch.optim import Adam
@@ -47,26 +48,64 @@ def gumbel_sample(logits, temperature = 1., eps = 1e-6):
      noise = gumbel_noise(logits)
      return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

+ # learn
+
+ def learn(
+     model,
+     actor_optim,
+     critic_optim,
+     accelerator,
+     replay,
+     batch_size = 16,
+     epochs = 2,
+ ):
+     dl = replay.dataloader(batch_size = batch_size, shuffle = True)
+     model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
+
+     for _ in range(epochs):
+         for data in dl:
+
+             data = SimpleNamespace(**data)
+
+             actor_loss, critic_loss = model.ppo(
+                 state = data.state,
+                 action = data.action,
+                 old_action_log_prob = data.action_log_prob,
+                 reward = data.reward,
+                 old_value = data.value,
+                 mask = data.learnable,
+                 episode_lens = data._lens,
+                 actor_optim = actor_optim,
+                 critic_optim = critic_optim
+             )
+
+             accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
  # main function

  def main(
      env_name = 'LunarLander-v3',
      num_episodes = 50_000,
      max_timesteps = 500,
-     num_timestep_before_learn = 5000,
+     num_episodes_before_learn = 64,
      clear_video = True,
      video_folder = 'recordings',
      record_every_episode = 250,
+     learning_rate = 8e-4,
      discount_factor = 0.99,
-     learning_rate = 1e-4,
+     betas = (0.9, 0.99),
+     gae_lam = 0.95,
+     ppo_eps_clip = 0.2,
+     ppo_entropy_weight = .01,
      batch_size = 16,
-     epochs = 2
+     epochs = 3,
+     reward_range = (-100., 100.)
  ):

      # accelerate

-     accelerate = Accelerator()
-     device = accelerate.device
+     accelerator = Accelerator()
+     device = accelerator.device

      # environment

@@ -91,14 +130,15 @@ def main(
      replay = ReplayBuffer(
          'replay',
          num_episodes,
-         max_timesteps,
+         max_timesteps + 1, # one extra node for bootstrap node - not relevant for locoformer, but for completeness
          fields = dict(
              state = ('float', (dim_state,)),
              action = 'int',
              action_log_prob = 'float',
              reward = 'float',
              value = 'float',
-             done = 'bool'
+             done = 'bool',
+             learnable = 'bool'
          )
      )

@@ -107,20 +147,30 @@ def main(
      locoformer = Locoformer(
          embedder = MLP(dim_state, 64, bias = False),
          unembedder = MLP(64, num_actions, bias = False),
-         value_network = MLP(64, 1, bias = False),
          transformer = dict(
              dim = 64,
              dim_head = 32,
              heads = 4,
              depth = 4,
              window_size = 16
-         )
+         ),
+         discount_factor = discount_factor,
+         gae_lam = gae_lam,
+         ppo_eps_clip = ppo_eps_clip,
+         ppo_entropy_weight = ppo_entropy_weight,
+         use_spo = True,
+         value_network = MLP(64, 64),
+         dim_value_input = 64,
+         reward_range = reward_range,
+         hl_gauss_loss_kwargs = dict(),
+         recurrent_kv_cache = True,
+         calc_gae_kwargs = dict(
+             use_accelerated = False
+         ),
      ).to(device)

-     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
-     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)
-
-     timesteps_learn = 0
+     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
+     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)

      # able to wrap the env for all values to torch tensors and back
      # all environments should follow usual MDP interface, domain randomization should be given at instantiation
@@ -129,7 +179,8 @@ def main(

      # loop

-     for _ in tqdm(range(num_episodes)):
+     for episodes_index in tqdm(range(num_episodes)):
+
          state, *_ = env_reset()

          timestep = 0
@@ -151,42 +202,59 @@ def main(

              # append to memory

-             done = truncated or terminated
+             exceeds_max_timesteps = timestep == (max_timesteps - 1)
+             done = truncated or terminated or tensor(exceeds_max_timesteps)

              # get log prob of action

              action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
              action_log_prob = rearrange(action_log_prob, '1 ->')

-             replay.store(
+             memory = replay.store(
                  state = state,
                  action = action,
                  action_log_prob = action_log_prob,
                  reward = reward,
                  value = value,
-                 done = done
+                 done = done,
+                 learnable = tensor(True)
              )

              # increment counters

              timestep += 1
-             timesteps_learn += 1

-             # learn if hit the number of learn timesteps
+             # break if done or exceed max timestep

-             if timesteps_learn >= num_timestep_before_learn:
-                 # todo - carry out learning
+             if done:

-                 timesteps_learn = 0
-                 memories.clear()
+                 # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                 # only if terminated signal not detected

-             # break if done or exceed max timestep
+                 if not terminated:
+                     _, next_value = stateful_forward(next_state, return_values = True)
+
+                     memory._replace(value = next_value, learnable = False)
+
+                     replay.store(**memory._asdict())

-             if done or timestep >= max_timesteps:
                  break

              state = next_state

+         # learn if hit the number of learn timesteps
+
+         if divisible_by(episodes_index + 1, num_episodes_before_learn):
+
+             learn(
+                 locoformer,
+                 optim_actor,
+                 optim_critic,
+                 accelerator,
+                 replay,
+                 batch_size,
+                 epochs,
+             )
  # main

  if __name__ == '__main__':