locoformer 0.0.30__tar.gz → 0.0.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.30
+Version: 0.0.37
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -41,6 +41,7 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
+Requires-Dist: x-evolution
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
 Requires-Dist: accelerate; extra == 'examples'
@@ -1,11 +1,14 @@
 from __future__ import annotations
 from typing import Callable
+from types import SimpleNamespace
 from functools import partial

 from pathlib import Path
 from contextlib import contextmanager
 from collections import namedtuple

+from inspect import signature
+
 import numpy as np
 from numpy import ndarray
 from numpy.lib.format import open_memmap
@@ -31,6 +34,10 @@ from hl_gauss_pytorch import HLGaussLoss

 from assoc_scan import AssocScan

+from x_mlps_pytorch import MLP
+
+from x_evolution import EvoStrategy
+
 # constants

 LinearNoBias = partial(Linear, bias = False)
@@ -54,6 +61,10 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0

+def get_param_names(fn):
+    parameters = signature(fn).parameters
+    return list(parameters.keys())
+
 # tensor helpers

 def log(t, eps = 1e-20):
@@ -81,10 +92,99 @@ def pad_at_dim(
 def normalize(t, eps = 1e-5):
     return (t - t.mean()) / t.std().clamp_min(eps)

+def tensor_to_dict(
+    t: Tensor,
+    config: tuple[tuple[str, int] | str],
+    dim = -1,
+    return_dottable = True
+):
+    config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+    names, sizes = zip(*config)
+    assert sum(sizes) == t.shape[dim]
+
+    t = t.split(sizes, dim = dim)
+    tensor_dict = dict(zip(names, t))
+
+    if not return_dottable:
+        return tensor_dict
+
+    return SimpleNamespace(**tensor_dict)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
     return -(prob * log(prob)).sum(dim = -1)

+# reward functions - A.2
+
+def reward_linear_velocity_command_tracking(
+    state,
+    command,
+    s1 = 1.
+):
+    if not (hasattr(state, 'v_xy') and hasattr(command, 'v_xy')):
+        return 0.
+
+    error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+    return torch.exp(-error / s1)
+
+def reward_angular_velocity_command_tracking(
+    state,
+    command,
+    s2 = 1.
+):
+    if not (hasattr(state, 'w_z') and hasattr(command, 'w_z')):
+        return 0.
+
+    error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+    return torch.exp(-error / s2)
+
+def reward_base_linear_velocity_penalty(
+    state
+):
+    if not hasattr(state, 'v_z'):
+        return 0.
+
+    return -state.v_z.norm(dim = -1).pow(2)
+
+def reward_base_angular_velocity_penalty(
+    state
+):
+    if not hasattr(state, 'w_xy'):
+        return 0.
+
+    return -state.w_xy.norm(dim = -1).pow(2)
+
+def reward_base_height_penalty(
+    state,
+    x_z_nominal = 0.27
+):
+    if not hasattr(state, 'x_z'):
+        return 0.
+
+    return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+def reward_joint_acceleration_penalty(
+    state
+):
+    if not hasattr(state, 'joint_q'):
+        return 0.
+
+    return -state.joint_q.norm(dim = -1).pow(2)
+
+def reward_torque_penalty(
+    state
+):
+    if not hasattr(state, 'tau'):
+        return 0.
+
+    return -state.tau.norm(dim = -1).pow(2)
+
+def reward_alive(
+    state
+):
+    return 1.
+
 # generalized advantage estimate

 @torch.no_grad()
@@ -487,8 +587,8 @@ class MaybeAdaRMSNormWrapper(Module):
         self.to_gamma = LinearNoBias(dim_cond, dim)
         self.to_ada_norm_zero = nn.Linear(dim_cond, dim)

-        nn.init.zeros_(self.to_gamma.weight, 0.)
-        nn.init.zeros_(self.to_ada_norm_zero.weight, 0.)
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.zeros_(self.to_ada_norm_zero.weight)
         nn.init.constant_(self.to_ada_norm_zero.bias, -5.)

     def forward(
@@ -499,6 +599,7 @@ class MaybeAdaRMSNormWrapper(Module):
     ):

         need_cond = self.accept_condition
+
         assert xnor(exists(cond), need_cond)

         prenormed = self.norm(x)
@@ -683,7 +784,9 @@ class TransformerXL(Module):

         condition = exists(dim_cond)

-        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = dim_cond)
+        self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)

         layers = ModuleList([])

@@ -710,20 +813,32 @@ class TransformerXL(Module):
         self,
         x,
         cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        condition: Tensor | None = None
     ):

+        # cache and residuals
+
         cache = default(cache, (None,) * len(self.layers))

         next_kv_caches = []
         value_residual = None

+        # handle condition
+
+        cond_tokens = None
+        if exists(condition):
+            assert exists(self.to_cond_tokens)
+            cond_tokens = self.to_cond_tokens(condition)
+
+        # layers
+
         for (attn, ff), kv_cache in zip(self.layers, cache):

-            attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
+            attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)

             x = attn_out + x
-            x = ff(x) + x
+            x = ff(x, cond = cond_tokens) + x

             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)
@@ -752,10 +867,10 @@ class Locoformer(Module):
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
-        dim_value_input = None, # needs to be set for value network to be available
+        dim_value_input = None, # needs to be set for value network to be available
         value_network: Module = nn.Identity(),
         reward_range: tuple[float, float] | None = None,
-        reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+        reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
         num_reward_bins = 32,
         hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
@@ -838,6 +953,14 @@ class Locoformer(Module):

         return self.to_value_pred.parameters()

+    def evolve(
+        self,
+        environment,
+        **kwargs
+    ):
+        evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+        evo_strat()
+
     def ppo(
         self,
         state,
@@ -948,16 +1071,33 @@ class Locoformer(Module):

         return mean_actor_loss.detach(), mean_critic_loss.detach()

-    def state_to_rewards(
+    def state_and_command_to_rewards(
         self,
-        state
+        state,
+        commands = None
     ) -> Tensor:

         assert self.has_reward_shaping

-        rewards = [fn(state) for fn in self.reward_shaping_fns]
+        rewards = []
+
+        for fn in self.reward_shaping_fns:
+            param_names = get_param_names(fn)
+            param_names = set(param_names) & {'state', 'command'}
+
+            if param_names == {'state'}: # only state
+                reward = fn(state = state)
+            elif param_names == {'state', 'command'}: # state and command
+                reward = fn(state = state, command = commands)
+            else:
+                raise ValueError('invalid number of arguments for reward shaping function')
+
+            rewards.append(reward)
+
+        # cast to Tensor if returns a float, just make it flexible for researcher

         rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+
         return stack(rewards)

     def wrap_env_functions(self, env):
@@ -987,7 +1127,7 @@ class Locoformer(Module):
             if not self.has_reward_shaping:
                 return env_step_out_torch

-            shaped_rewards = self.state_to_rewards(env_step_out_torch)
+            shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)

             return env_step_out_torch, shaped_rewards

@@ -1006,7 +1146,11 @@ class Locoformer(Module):

         cache = None

-        def stateful_forward(state: Tensor, **override_kwargs):
+        def stateful_forward(
+            state: Tensor,
+            condition: Tensor | None = None,
+            **override_kwargs
+        ):
             nonlocal cache

             # handle no batch or time, for easier time rolling out against envs
@@ -1014,12 +1158,18 @@ class Locoformer(Module):
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')

+                if exists(condition):
+                    condition = rearrange(condition, '... -> 1 ...')
+
             if not has_time_dim:
                 state = state.unsqueeze(state_time_dim)

+                if exists(condition):
+                    condition = rearrange(condition, '... d -> ... 1 d')
+
             # forwards

-            out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
+            out, cache = self.forward(state, condition = condition, cache = cache, **{**kwargs, **override_kwargs})

             # maybe remove batch or time

@@ -1054,6 +1204,7 @@ class Locoformer(Module):
         self,
         state: Tensor,
         cache: Cache | None = None,
+        condition: Tensor | None = None,
         detach_cache = False,
         return_values = False,
         return_raw_value_logits = False
@@ -1081,7 +1232,7 @@ class Locoformer(Module):

         # attention

-        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)
+        embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)

         # unembed to actions - in language models this would be the next state

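The main behavioral change above is that reward shaping functions may now take a command in addition to the state: state_and_command_to_rewards inspects each function's signature via get_param_names and only passes the command to functions that declare a `command` parameter. The following standalone sketch is not part of the package; the toy state, command, helper name, and reward functions here are invented purely to illustrate that dispatch pattern under the same signature-inspection idea.

# Standalone illustration of the signature-based reward dispatch.
# Everything below is hypothetical example code, not the package API.

from inspect import signature
from types import SimpleNamespace

import torch

def get_param_names(fn):
    # same idea as the helper in the diff: read parameter names off the signature
    return list(signature(fn).parameters.keys())

# two toy shaping functions: one reads only the state, one also needs the command
def height_penalty(state):
    return -(state.x_z - 0.27).pow(2)

def velocity_tracking(state, command):
    return torch.exp(-(state.v_xy - command.v_xy).norm(dim = -1).pow(2))

def shaped_rewards(fns, state, command = None):
    rewards = []
    for fn in fns:
        names = set(get_param_names(fn)) & {'state', 'command'}

        if names == {'state'}:
            reward = fn(state = state)
        elif names == {'state', 'command'}:
            reward = fn(state = state, command = command)
        else:
            raise ValueError('shaping functions must accept `state` and optionally `command`')

        rewards.append(torch.as_tensor(reward))

    return torch.stack(rewards)

state = SimpleNamespace(x_z = torch.tensor(0.30), v_xy = torch.tensor([0.5, 0.0]))
command = SimpleNamespace(v_xy = torch.tensor([0.4, 0.0]))

print(shaped_rewards([height_penalty, velocity_tracking], state, command))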
@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.30"
+version = "0.0.37"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -33,6 +33,7 @@ dependencies = [
     "hl-gauss-pytorch>=0.2.0",
     "rotary-embedding-torch",
     "torch>=2.4",
+    "x-evolution",
     "x-mlps-pytorch",
 ]

@@ -10,8 +10,10 @@ from einops import rearrange
 from locoformer.locoformer import Locoformer

 @param('recurrent_kv_cache', (False, True))
+@param('has_commands', (False, True))
 def test_locoformer(
-    recurrent_kv_cache
+    recurrent_kv_cache,
+    has_commands
 ):

     model = Locoformer(
@@ -24,24 +26,31 @@ def test_locoformer(
         transformer = dict(
             dim = 128,
             depth = 1,
-            window_size = 512
+            window_size = 512,
+            dim_cond = 2 if has_commands else None
         )
     )

     seq = torch.randint(0, 256, (3, 512))

-    (logits, values), cache = model(seq, return_values = True)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
+    commands = None
+    if has_commands:
+        commands = torch.randn(3, 512, 2)
+
+    (logits, values), cache = model(seq, condition = commands, return_values = True)
+    (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)
+    (logits, values), cache = model(seq, condition = commands, return_values = True, cache = cache)

     assert logits.shape == (3, 512, 256)

     stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)

+    inference_command = torch.randn(1, 1, 2) if has_commands else None
+
     for state in seq.unbind(dim = -1):
         state = rearrange(state, 'b -> b 1')

-        logits, values = stateful_forward(state)
+        logits, values = stateful_forward(state, condition = inference_command)
         assert logits.shape == (3, 1, 256)

 def test_replay():
@@ -117,7 +126,7 @@ def test_reward_shaping():
         reward_range = (-100., 100.),
         reward_shaping_fns = [
             lambda state: (state[3] - 2.5).pow(2).mean(),
-            lambda state: state[4:6].norm(dim = -1)
+            lambda state, command: state[4:6].norm(dim = -1)
         ],
         transformer = dict(
             dim = 128,
@@ -145,3 +154,29 @@ def test_reward_shaping():
     _, rewards = step_fn(3)

     assert len(rewards) == 2
+
+def test_tensor_to_dict():
+    state = torch.randn(1, 3, 5)
+    config = (('xyz', 3), 'vx', 'vy')
+
+    from locoformer.locoformer import tensor_to_dict
+
+    state_dict = tensor_to_dict(state, config)
+    assert hasattr(state_dict, 'xyz') and state_dict.xyz.shape == (1, 3, 3)
+
+def test_evo():
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 64, 32),
+        dim_value_input = 32,
+        reward_range = (-100., 100.),
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512,
+        )
+    )
+
+    model.evolve(lambda model: 1., num_generations = 1)
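As a reading aid for how the pieces added in this release could fit together, here is a hypothetical sketch. The field layout (v_xy, w_z, x_z) and its sizes are assumptions made only for this example; tensor_to_dict and the two reward functions themselves are the ones defined in the diff above. A flat observation tensor is split into a dottable namespace whose attribute names match what the A.2 reward functions look up via hasattr.

# Hypothetical wiring of tensor_to_dict with the A.2 reward shaping functions.
# The observation layout below is invented for illustration.

import torch
from locoformer.locoformer import (
    tensor_to_dict,
    reward_linear_velocity_command_tracking,
    reward_base_height_penalty
)

# a flat 4-dim observation: planar velocity (2), yaw rate (1), base height (1)
obs = torch.randn(4)
state = tensor_to_dict(obs, (('v_xy', 2), 'w_z', 'x_z'))

# commands share the same dottable layout
command = tensor_to_dict(torch.tensor([0.5, 0.0, 0.1]), (('v_xy', 2), 'w_z'))

print(reward_linear_velocity_command_tracking(state, command))  # exp(-||v_xy error||^2 / s1)
print(reward_base_height_penalty(state))                        # -(x_z - 0.27)^2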
8 files without changes