locoformer 0.0.15__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/locoformer.py CHANGED
@@ -1,10 +1,14 @@
  from __future__ import annotations
- from functools import partial
+ from typing import Callable
+ from types import SimpleNamespace
+ from functools import partial, wraps

  from pathlib import Path
  from contextlib import contextmanager
  from collections import namedtuple

+ from inspect import signature
+
  import numpy as np
  from numpy import ndarray
  from numpy.lib.format import open_memmap
@@ -16,8 +20,9 @@ import torch
  from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
- from torch.utils._pytree import tree_map
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
  from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Optimizer

  import einx
  from einops import rearrange, einsum
@@ -25,10 +30,20 @@ from einops.layers.torch import Rearrange

  from rotary_embedding_torch import RotaryEmbedding

+ from hl_gauss_pytorch import HLGaussLoss
+
  from assoc_scan import AssocScan

+ from x_mlps_pytorch import MLP
+
+ from x_evolution import EvoStrategy
+
+ # constants
+
  LinearNoBias = partial(Linear, bias = False)

+ Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
  # helper functions

  def exists(v):
@@ -40,20 +55,51 @@ def default(v, d):

  def first(arr):
      return arr[0]

+ def xnor(x, y):
+     return not (x ^ y)
+
  def divisible_by(num, den):
      return (num % den) == 0

+ def get_param_names(fn):
+     parameters = signature(fn).parameters
+     return list(parameters.keys())
+
+ def check_has_param_attr(
+     param_name,
+     param_attr,
+     default_value = None
+ ):
+     def decorator(fn):
+         sig = signature(fn)
+
+         @wraps(fn)
+         def inner(*args, **kwargs):
+
+             bound_args = sig.bind(*args, **kwargs).arguments
+
+             if not (
+                 param_name in bound_args and
+                 hasattr(bound_args[param_name], param_attr)
+             ):
+                 return default_value
+
+             return fn(*args, **kwargs)
+
+         return inner
+     return decorator
+
  # tensor helpers

  def log(t, eps = 1e-20):
      return t.clamp_min(eps).log()

+ def is_empty(t):
+     return t.numel() == 0
+
  def tree_map_tensor(x, fn):
      return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

- def detach_all(x):
-     return tree_map_tensor(x, lambda t: t.detach())
-
  def pad_at_dim(
      t,
      pad: tuple[int, int],
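The new `check_has_param_attr` decorator is what lets the reward functions added further down fail soft when a state or command simply does not carry a given field. A minimal sketch of that behaviour, assuming locoformer 0.0.43 and torch are importable:

```python
import torch
from types import SimpleNamespace

# reward functions decorated with check_has_param_attr are defined later in this diff
from locoformer.locoformer import reward_linear_velocity_command_tracking

state = SimpleNamespace(v_xy = torch.tensor([0.1, 0.0]))
command = SimpleNamespace(v_xy = torch.tensor([0.2, 0.0]))

# both arguments expose `v_xy`, so the wrapped function runs
print(reward_linear_velocity_command_tracking(state = state, command = command))  # ~exp(-0.01)

# a state without `v_xy` short-circuits to the decorator's default (None)
print(reward_linear_velocity_command_tracking(state = SimpleNamespace(), command = command))  # None
```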
@@ -67,10 +113,90 @@ def pad_at_dim(
      zeros = ((0, 0) * dims_from_right)
      return F.pad(t, (*zeros, *pad), value = value)

+ def normalize(t, eps = 1e-5):
+     return (t - t.mean()) / t.std().clamp_min(eps)
+
+ def tensor_to_dict(
+     t: Tensor,
+     config: tuple[tuple[str, int] | str],
+     dim = -1,
+     return_dottable = True
+ ):
+     config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+     names, sizes = zip(*config)
+     assert sum(sizes) == t.shape[dim]
+
+     t = t.split(sizes, dim = dim)
+     tensor_dict = dict(zip(names, t))
+
+     if not return_dottable:
+         return tensor_dict
+
+     return SimpleNamespace(**tensor_dict)
+
  def calc_entropy(logits):
      prob = logits.softmax(dim = -1)
      return -(prob * log(prob)).sum(dim = -1)

+ # reward functions - A.2
+
+ @check_has_param_attr('state', 'v_xy')
+ @check_has_param_attr('command', 'v_xy')
+ def reward_linear_velocity_command_tracking(
+     state,
+     command,
+     s1 = 1.
+ ):
+     error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+     return torch.exp(-error / s1)
+
+ @check_has_param_attr('state', 'w_z')
+ @check_has_param_attr('command', 'w_z')
+ def reward_angular_velocity_command_tracking(
+     state,
+     command,
+     s2 = 1.
+ ):
+     error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+     return torch.exp(-error / s2)
+
+ @check_has_param_attr('state', 'v_z')
+ def reward_base_linear_velocity_penalty(
+     state
+ ):
+     return -state.v_z.norm(dim = -1).pow(2)
+
+ @check_has_param_attr('state', 'w_xy')
+ def reward_base_angular_velocity_penalty(
+     state
+ ):
+     return -state.w_xy.norm(dim = -1).pow(2)
+
+ @check_has_param_attr('state', 'x_z')
+ def reward_base_height_penalty(
+     state,
+     x_z_nominal = 0.27
+ ):
+     return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+ @check_has_param_attr('state', 'joint_q')
+ def reward_joint_acceleration_penalty(
+     state
+ ):
+     return -state.joint_q.norm(dim = -1).pow(2)
+
+ @check_has_param_attr('state', 'tau')
+ def reward_torque_penalty(
+     state
+ ):
+     return -state.tau.norm(dim = -1).pow(2)
+
+ def reward_alive(
+     state
+ ):
+     return 1.
+
  # generalized advantage estimate

  @torch.no_grad()
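`tensor_to_dict` is what lets these reward shaping functions address named slices of a flat observation vector. A hedged usage sketch (the field layout below is made up for illustration):

```python
import torch
from locoformer.locoformer import tensor_to_dict, reward_base_height_penalty

# a flat 7-dim observation, split into named, dot-accessible fields
# sizes must sum to the tensor's size along `dim`; bare strings mean size 1
obs = torch.randn(7)

state = tensor_to_dict(obs, (('v_xy', 2), 'v_z', ('w_xy', 2), 'w_z', 'x_z'))

print(state.v_xy.shape)                           # torch.Size([2])
print(reward_base_height_penalty(state = state))  # -(x_z - 0.27)^2
```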
@@ -100,7 +226,7 @@ def calc_gae(

      returns = gae + values

-     return returns
+     return gae, returns

  # transformer-xl mask w/ flex attn

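`calc_gae` now returns the advantages alongside the bootstrapped returns. For reference, a plain-loop version of the recurrence it computes (the package itself uses an associative scan; treating the value after the final step as zero is an assumption of this sketch):

```python
import torch

def gae_reference(rewards, values, gamma = 0.999, lam = 0.95):
    # delta_t = r_t + gamma * V_{t+1} - V_t
    # A_t     = delta_t + gamma * lam * A_{t+1}
    timesteps = rewards.shape[-1]
    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[..., 0])

    for t in reversed(range(timesteps)):
        next_value = values[..., t + 1] if (t + 1) < timesteps else torch.zeros_like(values[..., t])
        delta = rewards[..., t] + gamma * next_value - values[..., t]
        running = delta + gamma * lam * running
        advantages[..., t] = running

    returns = advantages + values
    return advantages, returns
```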
@@ -250,6 +376,74 @@ class ReplayDataset(Dataset):

          return data

+ class RemappedReplayDataset(Dataset):
+     def __init__(
+         self,
+         dataset: ReplayDataset,
+         episode_mapping: Tensor | list[list[int]],
+         shuffle_episodes = False,
+         num_trials_select = None
+     ):
+         assert len(dataset) > 0
+         self.dataset = dataset
+
+         if is_tensor(episode_mapping):
+             assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+             episode_mapping = episode_mapping.tolist()
+
+         self.episode_mapping = episode_mapping
+         self.shuffle_episodes = shuffle_episodes
+
+         assert not exists(num_trials_select) or num_trials_select >= 1
+         self.sub_select_trials = exists(num_trials_select)
+         self.num_trials_select = num_trials_select
+
+     def __len__(self):
+         return len(self.episode_mapping)
+
+     def __getitem__(self, idx):
+
+         episode_indices = self.episode_mapping[idx]
+
+         episode_indices = tensor(episode_indices)
+         episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+         assert not is_empty(episode_indices)
+
+         # shuffle the episode indices if either shuffle episodes is turned on, or `num_trials_select` is passed in (for sub-selecting episodes from a set)
+
+         if (
+             episode_indices.numel() > 1 and
+             (self.shuffle_episodes or self.sub_select_trials)
+         ):
+             num_episodes = len(episode_indices)
+             episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+         # crop out the episodes
+
+         if self.sub_select_trials:
+             episode_indices = episode_indices[:self.num_trials_select]
+
+         # now select out the episode data and merge along time
+
+         episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+         episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+         keys = first(episode_data).keys()
+
+         values = [list(data.values()) for data in episode_data]
+
+         values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+         multi_episode_data = dict(zip(keys, values))
+
+         multi_episode_data['_lens'] = episode_lens.sum()
+
+         multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+         return multi_episode_data
+
  class ReplayBuffer:

      @beartype
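The `episode_mapping` accepted by `RemappedReplayDataset` groups stored episodes into longer multi-trial sequences; out-of-range indices are filtered out, so ragged groups can be padded with -1. A hedged sketch, where `buffer` stands for an already-populated `ReplayBuffer` and the `dataset` / `dataloader` hooks referenced in the comments appear a few hunks below:

```python
# each row becomes one dataset item, with the listed episodes concatenated along time
episode_mapping = [
    [0, 1, 2],
    [3, 4, -1],  # -1 is dropped by the index filter, leaving a two-trial sequence
]

# hypothetical usage against an existing ReplayBuffer instance named `buffer`
# dataset = buffer.dataset(episode_mapping = episode_mapping)
# dataloader = buffer.dataloader(batch_size = 2, episode_mapping = episode_mapping)
```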
@@ -306,6 +500,10 @@ class ReplayBuffer:
              # memmap file

              filepath = folder / f'{field_name}.data.npy'
+
+             if isinstance(shape, int):
+                 shape = (shape,)
+
              memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))

              self.memmaps[field_name] = memmap
@@ -314,6 +512,9 @@ class ReplayBuffer:

          self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))

+     def __len__(self):
+         return (self.episode_lens > 0).sum().item()
+
      def reset_(self):
          self.episode_lens[:] = 0
          self.episode_index = 0
@@ -375,15 +576,92 @@ class ReplayBuffer:

          return self.memory_namedtuple(**data)

-     def dataset(self) -> Dataset:
+     def dataset(
+         self,
+         episode_mapping: Tensor | list[list[int]] | None = None,
+     ) -> Dataset:
          self.flush()

-         return ReplayDataset(self.folder)
+         dataset = ReplayDataset(self.folder)
+
+         if not exists(episode_mapping):
+             return dataset

-     def dataloader(self, batch_size, **kwargs) -> DataLoader:
+         return RemappedReplayDataset(dataset, episode_mapping)
+
+     def dataloader(
+         self,
+         batch_size,
+         episode_mapping: Tensor | list[list[int]] | None = None,
+         **kwargs
+     ) -> DataLoader:
          self.flush()

-         return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+         return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
+ # normalization + conditioning (needed for the commands to the robot)
+
+ class MaybeAdaRMSNormWrapper(Module):
+     def __init__(
+         self,
+         fn: Module,
+         dim,
+         dim_cond = None
+     ):
+         super().__init__()
+         condition = exists(dim_cond)
+
+         self.fn = fn
+         self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
+
+         self.accept_condition = condition
+
+         if condition:
+             self.to_gamma = LinearNoBias(dim_cond, dim)
+             self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
+
+             nn.init.zeros_(self.to_gamma.weight)
+             nn.init.zeros_(self.to_ada_norm_zero.weight)
+             nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
+
+     def forward(
+         self,
+         x,
+         cond = None,
+         **kwargs
+     ):
+
+         need_cond = self.accept_condition
+
+         assert xnor(exists(cond), need_cond)
+
+         prenormed = self.norm(x)
+
+         if need_cond:
+             if cond.ndim == 2:
+                 cond = rearrange(cond, 'b d -> b 1 d')
+
+             scale_in = self.to_gamma(cond)
+             prenormed = prenormed * (scale_in + 1.)
+
+         all_fn_out = self.fn(prenormed, **kwargs)
+
+         if not need_cond:
+             return all_fn_out
+
+         # function may return multiple args
+
+         (out, *rest), tree_spec = tree_flatten(all_fn_out)
+
+         if need_cond:
+             scale_out = self.to_ada_norm_zero(cond).sigmoid()
+             out = out * scale_out
+
+         # restore
+
+         all_fn_out = tree_unflatten((out, *rest), tree_spec)
+
+         return all_fn_out

  # transformer-xl with ppo

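The conditioning in `MaybeAdaRMSNormWrapper` is initialized to be a near no-op: `to_gamma` starts at zero so the pre-norm scale is exactly 1, and the `to_ada_norm_zero` gate starts with a -5 bias so each wrapped branch is almost fully suppressed until the conditioning is learned. A tiny numeric check of those two initializations:

```python
import torch

cond = torch.randn(2, 64)
to_gamma_weight = torch.zeros(32, 64)          # zero-initialized, as in the wrapper

scale_in = cond @ to_gamma_weight.t() + 1.     # equals 1 everywhere at init (identity scale)
gate = torch.full((32,), -5.).sigmoid()        # ~0.0067, so the branch output starts nearly closed

print(scale_in.unique(), gate[0])
```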
@@ -394,15 +672,12 @@ class Attention(Module):
          window_size,
          dim_head = 64,
          heads = 8,
-         pre_rmsnorm = True,
          fixed_window_size = False,
          accept_value_residual = False
      ):
          super().__init__()
          self.scale = dim_head ** -0.5

-         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
          self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
          self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -446,8 +721,6 @@ class Attention(Module):

          device = tokens.device

-         tokens = self.norm(tokens)
-
          q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

          q, k, v = map(self.split_heads, (q, k, v))
@@ -536,19 +809,26 @@ class TransformerXL(Module):
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
+         dim_cond = None,
          final_norm = True,
          fixed_window_size = False,
      ):
          super().__init__()

+         condition = exists(dim_cond)
+
+         self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+         norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
          layers = ModuleList([])

          for i in range(depth):
              is_first = i == 0

-             attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
+             attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))

-             ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
+             ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

              layers.append(ModuleList([
                  attn, ff
@@ -566,20 +846,32 @@ class TransformerXL(Module):
          self,
          x,
          cache = None,
-         return_kv_cache = False
+         return_kv_cache = False,
+         condition: Tensor | None = None
      ):

+         # cache and residuals
+
          cache = default(cache, (None,) * len(self.layers))

          next_kv_caches = []
          value_residual = None

+         # handle condition
+
+         cond_tokens = None
+         if exists(condition):
+             assert exists(self.to_cond_tokens)
+             cond_tokens = self.to_cond_tokens(condition)
+
+         # layers
+
          for (attn, ff), kv_cache in zip(self.layers, cache):

-             attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+             attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)

              x = attn_out + x
-             x = ff(x) + x
+             x = ff(x, cond = cond_tokens) + x

              next_kv_caches.append(next_kv_cache)
              value_residual = default(value_residual, values)
@@ -600,16 +892,24 @@ class TransformerXL(Module):
  class Locoformer(Module):
      def __init__(
          self,
-         embedder: Module,
+         embedder: Module | ModuleList | list[Module],
          unembedder: Module,
          transformer: dict | TransformerXL,
-         value_network: Module | None = None,
          discount_factor = 0.999,
          gae_lam = 0.95,
          ppo_eps_clip = 0.2,
          ppo_entropy_weight = 0.01,
          ppo_value_clip = 0.4,
-         value_loss_weight = 0.5
+         dim_value_input = None, # needs to be set for value network to be available
+         value_network: Module = nn.Identity(),
+         reward_range: tuple[float, float] | None = None,
+         reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
+         num_reward_bins = 32,
+         hl_gauss_loss_kwargs = dict(),
+         value_loss_weight = 0.5,
+         calc_gae_kwargs: dict = dict(),
+         recurrent_kv_cache = True,
+         use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
      ):
          super().__init__()

@@ -618,14 +918,41 @@ class Locoformer(Module):

          self.transformer = transformer

+         # handle state embedder
+
+         if isinstance(embedder, list):
+             embedder = ModuleList(embedder)
+
          self.embedder = embedder
-         self.unembedder = unembedder

-         self.value_network = value_network
+         # unembed state to actions or ssl predictions
+
+         self.unembedder = unembedder

          self.fixed_window_size = transformer.fixed_window_size
          self.window_size = transformer.window_size

+         # determine value network, using HL Gauss Layer
+
+         self.to_value_pred = None
+
+         if exists(dim_value_input):
+             assert exists(reward_range)
+
+             self.to_value_pred = nn.Sequential(
+                 value_network,
+                 LinearNoBias(dim_value_input, num_reward_bins)
+             )
+
+             reward_min, reward_max = reward_range
+
+             self.hl_gauss_loss = HLGaussLoss(
+                 min_value = reward_min,
+                 max_value = reward_max,
+                 num_bins = num_reward_bins,
+                 **hl_gauss_loss_kwargs
+             )
+
          # ppo related
          self.discount_factor = discount_factor
          self.gae_lam = gae_lam
@@ -635,6 +962,25 @@ class Locoformer(Module):
          self.ppo_value_clip = ppo_value_clip
          self.value_loss_weight = value_loss_weight

+         self.calc_gae_kwargs = calc_gae_kwargs
+
+         # maybe use spo
+
+         self.use_spo = use_spo
+
+         # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+         self.recurrent_kv_cache = recurrent_kv_cache
+
+         # reward shaping function
+
+         self.has_reward_shaping = exists(reward_shaping_fns)
+         self.reward_shaping_fns = reward_shaping_fns
+
+         # loss related
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
      @property
      def device(self):
          return next(self.parameters()).device
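The scalar value head is replaced above by an HL-Gauss classification head: the critic predicts a distribution over `num_reward_bins` bins, and the same `HLGaussLoss` module is used both as a loss against returns and, when called without a target, as a decoder back to scalar values (both usages appear in the `ppo` method below). A hedged sketch with made-up shapes:

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -10., max_value = 10., num_bins = 32)

value_logits = torch.randn(2, 8, 32)             # per-timestep value bin logits
returns = torch.empty(2, 8).uniform_(-10., 10.)  # regression targets

loss = hl_gauss(value_logits, returns, reduction = 'none')  # cross entropy vs a Gaussian histogram target
values = hl_gauss(value_logits)                              # no target -> decode logits to scalar values

print(loss.shape, values.shape)  # both come back per token
```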
@@ -643,10 +989,18 @@ class Locoformer(Module):
          return self.unembedder.parameters()

      def critic_parameters(self):
-         if not exists(self.value_network):
+         if not exists(self.to_value_pred):
              return []

-         return self.value_network.parameters()
+         return self.to_value_pred.parameters()
+
+     def evolve(
+         self,
+         environment,
+         **kwargs
+     ):
+         evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+         evo_strat()

      def ppo(
          self,
  self,
@@ -656,79 +1010,180 @@ class Locoformer(Module):
656
1010
  reward,
657
1011
  old_value,
658
1012
  mask,
659
- actor_optim,
660
- critic_optim
1013
+ episode_lens,
1014
+ condition: Tensor | None = None,
1015
+ state_type: int | None = None,
1016
+ actor_optim: Optimizer | None = None,
1017
+ critic_optim: Optimizer | None = None
661
1018
  ):
1019
+ window_size = self.window_size
1020
+ total_learnable_tokens = mask.sum().item()
662
1021
 
663
- (action_logits, value), _ = self.forward(state, return_values = True)
664
- entropy = calc_entropy(action_logits)
1022
+ seq_len = state.shape[1]
1023
+ gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
665
1024
 
666
- action = rearrange(action, 'b t -> b t 1')
667
- log_prob = action_logits.gather(-1, action)
668
- log_prob = rearrange(log_prob, 'b t 1 -> b t')
1025
+ advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
1026
+
1027
+ advantage = normalize(advantage)
1028
+
1029
+ data_tensors = (
1030
+ state,
1031
+ action,
1032
+ old_action_log_prob,
1033
+ reward,
1034
+ old_value,
1035
+ mask,
1036
+ advantage,
1037
+ returns
1038
+ )
669
1039
 
670
- # update actor, classic clipped surrogate loss
1040
+ has_condition = exists(condition)
671
1041
 
672
- eps_clip = self.ppo_eps_clip
673
- ratio = (log_prob - old_action_log_prob).exp()
1042
+ if exists(condition):
1043
+ data_tensors = (*data_tensors, condition)
1044
+
1045
+ windowed_tensors = [
1046
+ t.split(window_size, dim = 1) for t in
1047
+ data_tensors
1048
+ ]
1049
+
1050
+ mean_actor_loss = self.zero.clone()
1051
+ mean_critic_loss = self.zero.clone()
1052
+
1053
+ # learn across windows
1054
+
1055
+ cache = None
674
1056
 
675
- returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor)
676
- advantage = returns - old_value
1057
+ for (
1058
+ state,
1059
+ action,
1060
+ old_action_log_prob,
1061
+ reward,
1062
+ old_value,
1063
+ mask,
1064
+ advantage,
1065
+ returns,
1066
+ *rest
1067
+ ) in zip(*windowed_tensors):
677
1068
 
678
- actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
1069
+ if has_condition:
1070
+ condition, = rest
679
1071
 
680
- actor_loss = actor_loss - self.ppo_entropy_weight * entropy
1072
+ (action_logits, value_logits), cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
1073
+ entropy = calc_entropy(action_logits)
681
1074
 
682
- mean_actor_loss = actor_loss[mask].mean()
683
- mean_actor_loss.backward(retain_graph = True)
1075
+ action = rearrange(action, 'b t -> b t 1')
1076
+ log_prob = action_logits.gather(-1, action)
1077
+ log_prob = rearrange(log_prob, 'b t 1 -> b t')
684
1078
 
685
- # update critic
1079
+ # update actor, classic clipped surrogate loss
686
1080
 
687
- value_loss = F.mse_loss(returns, value, reduction = 'none')
1081
+ eps_clip = self.ppo_eps_clip
1082
+ ratio = (log_prob - old_action_log_prob).exp()
688
1083
 
689
- value_clip = self.ppo_value_clip
690
- clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
691
- clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
1084
+ if self.use_spo:
1085
+ actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
1086
+ else:
1087
+ actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
692
1088
 
693
- critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
1089
+ actor_loss = actor_loss - self.ppo_entropy_weight * entropy
694
1090
 
695
- mean_critic_loss = critic_loss[mask].mean()
696
- mean_critic_loss.backward()
1091
+ windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
1092
+ windowed_actor_loss.backward(retain_graph = True)
1093
+
1094
+ # update critic
1095
+
1096
+ value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
1097
+
1098
+ value_clip = self.ppo_value_clip
1099
+ value = self.hl_gauss_loss(value_logits)
1100
+
1101
+ clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
1102
+ clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
1103
+
1104
+ critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
1105
+
1106
+ windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
1107
+ windowed_critic_loss.backward(retain_graph = True)
1108
+
1109
+ # accumulate
1110
+
1111
+ mean_actor_loss.add_(windowed_actor_loss)
1112
+ mean_critic_loss.add_(windowed_critic_loss)
697
1113
 
698
1114
  # optimizer update
699
1115
 
700
- actor_optim.step()
701
- actor_optim.zero_grad()
1116
+ if exists(actor_optim):
1117
+ actor_optim.step()
1118
+ actor_optim.zero_grad()
702
1119
 
703
- critic_optim.step()
704
- critic_optim.zero_grad()
1120
+ if exists(critic_optim):
1121
+ critic_optim.step()
1122
+ critic_optim.zero_grad()
705
1123
 
706
1124
  # return losses for logging
707
1125
 
708
1126
  return mean_actor_loss.detach(), mean_critic_loss.detach()
709
1127
 
1128
+ def state_and_command_to_rewards(
1129
+ self,
1130
+ state,
1131
+ commands = None
1132
+ ) -> Tensor:
1133
+
1134
+ assert self.has_reward_shaping
1135
+
1136
+ rewards = []
1137
+
1138
+ for fn in self.reward_shaping_fns:
1139
+ param_names = get_param_names(fn)
1140
+ param_names = set(param_names) & {'state', 'command'}
1141
+
1142
+ if param_names == {'state'}: # only state
1143
+ reward = fn(state = state)
1144
+ elif param_names == {'state', 'command'}: # state and command
1145
+ reward = fn(state = state, command = commands)
1146
+ else:
1147
+ raise ValueError('invalid number of arguments for reward shaping function')
1148
+
1149
+ rewards.append(reward)
1150
+
1151
+ # cast to Tensor if returns a float, just make it flexible for researcher
1152
+
1153
+ rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
1154
+
1155
+ return stack(rewards)
1156
+
710
1157
  def wrap_env_functions(self, env):
711
1158
 
712
- def wrapped_reset(*args, **kwargs):
713
- state, _ = env.reset(*args, **kwargs)
1159
+ def transform_output(el):
1160
+ if isinstance(el, ndarray):
1161
+ return from_numpy(el)
1162
+ elif isinstance(el, (int, bool, float)):
1163
+ return tensor(el)
1164
+ else:
1165
+ return el
714
1166
 
715
- if isinstance(state, ndarray):
716
- state = from_numpy(state)
1167
+ def wrapped_reset(*args, **kwargs):
1168
+ env_reset_out = env.reset(*args, **kwargs)
717
1169
 
718
- return state, _
1170
+ return tree_map(transform_output, env_reset_out)
719
1171
 
720
1172
  def wrapped_step(action, *args, **kwargs):
721
- out = env.step(action.item(), *args, **kwargs)
722
1173
 
723
- def transform_output(el):
724
- if isinstance(el, ndarray):
725
- return from_numpy(el)
726
- elif isinstance(el, (int, bool, float)):
727
- return tensor(el)
728
- else:
729
- return el
1174
+ if is_tensor(action):
1175
+ action = action.item()
1176
+
1177
+ env_step_out = env.step(action, *args, **kwargs)
1178
+
1179
+ env_step_out_torch = tree_map(transform_output, env_step_out)
730
1180
 
731
- return tree_map(transform_output, out)
1181
+ if not self.has_reward_shaping:
1182
+ return env_step_out_torch
1183
+
1184
+ shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
1185
+
1186
+ return env_step_out_torch, shaped_rewards
732
1187
 
733
1188
  return wrapped_reset, wrapped_step
734
1189
 
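The actor update can now use either the classic PPO clipped surrogate or SPO's quadratically penalized ratio. Side by side, with made-up numbers:

```python
import torch

eps_clip = 0.2
ratio = torch.tensor([0.8, 1.0, 1.3])       # exp(log_prob - old_log_prob)
advantage = torch.tensor([1.0, -0.5, 2.0])  # normalized advantages

# PPO: hard-clip the ratio and take the pessimistic branch
ppo_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

# SPO: smooth quadratic penalty on how far the ratio drifts from 1
spo_loss = -(ratio * advantage - advantage.abs() * (ratio - 1.).square() / (2 * eps_clip))

print(ppo_loss)
print(spo_loss)
```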
@@ -738,38 +1193,48 @@ class Locoformer(Module):
          inference_mode = False,
          has_batch_dim = False,
          has_time_dim = False,
+         state_time_dim = 1,
          **kwargs
      ):
          window_size = self.window_size

          cache = None

-         def stateful_forward(state: Tensor, **override_kwargs):
+         def stateful_forward(
+             state: Tensor,
+             condition: Tensor | None = None,
+             state_type: int | None = None,
+             **override_kwargs
+         ):
              nonlocal cache

+             state = state.to(self.device)
+
+             if exists(condition):
+                 condition = condition.to(self.device)
+
              # handle no batch or time, for easier time rolling out against envs

              if not has_batch_dim:
                  state = rearrange(state, '... -> 1 ...')

-             if not has_time_dim:
-                 state = rearrange(state, '... d -> ... 1 d')
-
-             # forwards
+                 if exists(condition):
+                     condition = rearrange(condition, '... -> 1 ...')

-             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+             if not has_time_dim:
+                 state = state.unsqueeze(state_time_dim)

-             # handle cache
+                 if exists(condition):
+                     condition = rearrange(condition, '... d -> ... 1 d')

-             cache_len = cache.shape[-2]
+             # forwards

-             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-                 cache = cache[..., -window_size:, :]
+             out, cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, **{**kwargs, **override_kwargs})

              # maybe remove batch or time

              if not has_time_dim:
-                 out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
+                 out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

              if not has_batch_dim:
                  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -798,16 +1263,46 @@ class Locoformer(Module):
      def forward(
          self,
          state: Tensor,
-         cache: Tensor | None = None,
+         cache: Cache | None = None,
+         condition: Tensor | None = None,
+         state_type: int | None = None,
          detach_cache = False,
-         return_values = False
+         return_values = False,
+         return_raw_value_logits = False
      ):

          state = state.to(self.device)

-         tokens = self.embedder(state)
+         # determine which function to invoke for state to token for transformer
+
+         state_to_token = self.embedder
+
+         if exists(state_type):
+             state_to_token = self.embedder[state_type]
+
+         # embed
+
+         tokens = state_to_token(state)
+
+         # time
+
+         time = tokens.shape[-2]
+
+         # destructure the cache into the current timestep and the kv cache

-         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+         prev_kv_cache = None
+         timestep_start = 0
+
+         if exists(cache):
+             timestep_start, prev_kv_cache = cache
+
+         # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+         assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+         # attention
+
+         embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)

          # unembed to actions - in language models this would be the next state

@@ -818,21 +1313,34 @@ class Locoformer(Module):
          # maybe detach cache

          if detach_cache:
-             kv_cache = detach_all(kv_cache)
+             kv_cache = kv_cache.detach()

          # handle returning of values

          if return_values:
-             assert exists(self.value_network)
+             assert exists(self.to_value_pred)

-             values = self.value_network(embed)
+             values = self.to_value_pred(embed)

-             if values.ndim == 3:
-                 assert values.shape[-1] == 1
-                 values = rearrange(values, '... 1 -> ...')
+             if not return_raw_value_logits:
+                 values = self.hl_gauss_loss(values) # converts the value logits to scalar values

              out = (out, values)

          # output and cache

-         return out, kv_cache
+         next_timestep = time + timestep_start
+
+         # handle curtailing the kv cache at the right intervals
+
+         window_size = self.window_size
+
+         if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+             kv_cache = kv_cache[..., -window_size:, :]
+
+         # maybe recurrent cache - shift the kv cache from one layer above to the one below, extending the receptive field into the past
+
+         if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+             kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+         return out, (next_timestep, kv_cache)
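`forward` now returns a `(next_timestep, kv_cache)` pair, trims the cache at window boundaries, and, when `recurrent_kv_cache` is enabled, rolls it along the layer dimension so each layer reads what the layer above cached over the previous window (per Ding et al.). A toy illustration of that roll, with a stand-in tensor whose leading dimension plays the role of the layer axis:

```python
import torch

num_layers = 3

# stand-in kv cache: (layers, batch, heads, timesteps, dim_head), tagged by layer index
kv_cache = torch.arange(num_layers).float().view(-1, 1, 1, 1, 1).expand(-1, 1, 2, 4, 8).clone()

rolled = torch.roll(kv_cache, shifts = -1, dims = 0)

print(kv_cache[:, 0, 0, 0, 0])  # tensor([0., 1., 2.])
print(rolled[:, 0, 0, 0, 0])    # tensor([1., 2., 0.]) - layer i now holds layer i+1's cache
```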
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.15
+ Version: 0.0.43
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -38,8 +38,10 @@ Requires-Dist: assoc-scan
  Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hl-gauss-pytorch>=0.2.0
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: torch>=2.4
+ Requires-Dist: x-evolution
  Requires-Dist: x-mlps-pytorch
  Provides-Extra: examples
  Requires-Dist: accelerate; extra == 'examples'
@@ -54,7 +56,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=5gQTtseqs92K9ee9HJ1gEqhm8MFPFDFXPnoPxLnf8Nw,37531
+ locoformer-0.0.43.dist-info/METADATA,sha256=Vgx50wEmRpwrGxoOntARE2oU7g5TdqcM2ZUvrpOBjIk,3283
+ locoformer-0.0.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ locoformer-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.43.dist-info/RECORD,,
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: hatchling 1.27.0
+ Generator: hatchling 1.28.0
  Root-Is-Purelib: true
  Tag: py3-none-any
@@ -1,6 +0,0 @@
- locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
- locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
- locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
- locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- locoformer-0.0.15.dist-info/RECORD,,