locoformer 0.0.7__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
locoformer/locoformer.py CHANGED
@@ -1,11 +1,25 @@
  from __future__ import annotations
+ from typing import Callable
  from functools import partial

+ from pathlib import Path
+ from contextlib import contextmanager
+ from collections import namedtuple
+
+ import numpy as np
+ from numpy import ndarray
+ from numpy.lib.format import open_memmap
+
+ from beartype import beartype
+ from beartype.door import is_bearable
+
  import torch
- from torch import nn, cat, stack, arange, Tensor, is_tensor
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
  from torch.utils._pytree import tree_map
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Optimizer

  import einx
  from einops import rearrange, einsum
@@ -13,10 +27,16 @@ from einops.layers.torch import Rearrange

  from rotary_embedding_torch import RotaryEmbedding

+ from hl_gauss_pytorch import HLGaussLoss
+
  from assoc_scan import AssocScan

+ # constants
+
  LinearNoBias = partial(Linear, bias = False)

+ Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
  # helper functions

  def exists(v):
@@ -31,20 +51,36 @@ def first(arr):
  def divisible_by(num, den):
  return (num % den) == 0

+ # tensor helpers
+
+ def log(t, eps = 1e-20):
+ return t.clamp_min(eps).log()
+
+ def is_empty(t):
+ return t.numel() == 0
+
  def tree_map_tensor(x, fn):
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

- def detach_all(x):
- return tree_map_tensor(x, lambda t: t.detach())
+ def pad_at_dim(
+ t,
+ pad: tuple[int, int],
+ dim = -1,
+ value = 0.
+ ):
+ if pad == (0, 0):
+ return t

- def combine_kv_cache(cache1, cache2):
- combined_cache = []
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = ((0, 0) * dims_from_right)
+ return F.pad(t, (*zeros, *pad), value = value)

- for layer_cache1, layer_cache2 in zip(cache1, cache2):
- next_cache = cat((layer_cache1, layer_cache2), dim = -2)
- combined_cache.append(next_cache)
+ def normalize(t, eps = 1e-5):
+ return (t - t.mean()) / t.std().clamp_min(eps)

- return combined_cache
+ def calc_entropy(logits):
+ prob = logits.softmax(dim = -1)
+ return -(prob * log(prob)).sum(dim = -1)

  # generalized advantage estimate

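For orientation, a tiny self-contained check of the new tensor helpers above (assuming they are imported from `locoformer.locoformer`; shapes and values are illustrative):

    import math
    import torch
    from locoformer.locoformer import pad_at_dim, calc_entropy

    x = torch.randn(2, 3)

    # pad five zeros onto the end of dim 0
    assert pad_at_dim(x, (0, 5), dim = 0).shape == (7, 3)

    # uniform logits give an entropy of log(n) per row
    uniform_logits = torch.zeros(4, 10)
    assert torch.allclose(calc_entropy(uniform_logits), torch.full((4,), math.log(10.)))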
@@ -52,7 +88,7 @@ def combine_kv_cache(cache1, cache2):
  def calc_gae(
  rewards,
  values,
- masks,
+ masks = None,
  gamma = 0.99,
  lam = 0.95,
  use_accelerated = None
@@ -63,6 +99,9 @@ def calc_gae(
  values = F.pad(values, (0, 1), value = 0.)
  values, values_next = values[..., :-1], values[..., 1:]

+ if not exists(masks):
+ masks = torch.ones_like(values)
+
  delta = rewards + gamma * values_next * masks - values
  gates = gamma * lam * masks

@@ -72,7 +111,7 @@ def calc_gae(

  returns = gae + values

- return returns
+ return gae, returns

  # transformer-xl mask w/ flex attn

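`calc_gae` now makes `masks` optional and returns both the advantages and the bootstrapped returns. A minimal sketch of the new signature (tensor values here are illustrative, not from the package):

    import torch
    from locoformer.locoformer import calc_gae

    rewards = torch.tensor([[1., 0., 0., 1.]])      # (batch, time)
    values  = torch.tensor([[0.5, 0.4, 0.3, 0.2]])

    gae, returns = calc_gae(rewards, values)        # masks now defaults to all ones

    assert gae.shape == returns.shape == rewards.shape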
@@ -114,8 +153,8 @@ def create_xl_mask(
  # handle intra-episodic attention if needed

  if exists(episode_ids):
- q_episode = episodes[b, q + offset]
- k_episode = episodes[b, k]
+ q_episode = episode_ids[b, q + offset]
+ k_episode = episode_ids[b, k]

  intra_episode_mask = q_episode == k_episode
  mask = mask & intra_episode_mask
@@ -146,6 +185,284 @@ def create_sliding_mask(
  create_kwargs = dict(device = device) if exists(device) else dict()
  return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

+ # data
+
+ def collate_var_time(data):
+
+ datum = first(data)
+ keys = datum.keys()
+
+ all_tensors = zip(*[datum.values() for datum in data])
+
+ collated_values = []
+
+ for key, tensors in zip(keys, all_tensors):
+
+ # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+ if key != '_lens':
+
+ times = [t.shape[0] for t in tensors]
+ max_time = max(times)
+ tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+ collated_values.append(stack(tensors))
+
+ return dict(zip(keys, collated_values))
+
+ class ReplayDataset(Dataset):
+ def __init__(
+ self,
+ folder: str | Path,
+ fields: tuple[str, ...] | None = None
+ ):
+ if isinstance(folder, str):
+ folder = Path(folder)
+
+ episode_lens = folder / 'episode_lens.npy'
+ self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+ # get indices of non-zero lengthed episodes
+
+ nonzero_episodes = self.episode_lens > 0
+ self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+ # get all data files
+
+ filepaths = [*folder.glob('*.data.npy')]
+ assert len(filepaths) > 0
+
+ fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+ fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+ fields = default(fields, fieldnames_from_files)
+
+ self.memmaps = dict()
+
+ for field in fields:
+ assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+ path = fieldname_to_filepath[field]
+
+ self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, idx):
+ episode_index = self.indices[idx]
+
+ episode_len = self.episode_lens[episode_index]
+
+ data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
+
+ data['_lens'] = tensor(episode_len)
+
+ return data
+
+ class RemappedReplayDataset(Dataset):
+ def __init__(
+ self,
+ dataset: ReplayDataset,
+ episode_mapping: Tensor | list[list[int]],
+ shuffle_episodes = False
+ ):
+ assert len(dataset) > 0
+ self.dataset = dataset
+
+ if is_tensor(episode_mapping):
+ assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+ episode_mapping = episode_mapping.tolist()
+
+ self.episode_mapping = episode_mapping
+ self.shuffle_episodes = shuffle_episodes
+
+ def __len__(self):
+ return len(self.episode_mapping)
+
+ def __getitem__(self, idx):
+
+ episode_indices = self.episode_mapping[idx]
+
+ episode_indices = tensor(episode_indices)
+ episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+ assert not is_empty(episode_indices)
+
+ if self.shuffle_episodes and episode_indices.numel() > 1:
+ num_episodes = len(episode_indices)
+ episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+ episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+ episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+ keys = first(episode_data).keys()
+
+ values = [list(data.values()) for data in episode_data]
+
+ values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+ multi_episode_data = dict(zip(keys, values))
+
+ multi_episode_data['_lens'] = episode_lens.sum()
+
+ multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+ return multi_episode_data
+
+ class ReplayBuffer:
+
+ @beartype
+ def __init__(
+ self,
+ folder: str | Path,
+ max_episodes: int,
+ max_timesteps: int,
+ fields: dict[
+ str,
+ str | tuple[str, int | tuple[int, ...]]
+ ]
+ ):
+
+ # folder for data
+
+ if not isinstance(folder, Path):
+ folder = Path(folder)
+ folder.mkdir(exist_ok = True)
+
+ self.folder = folder
+ assert folder.is_dir()
+
+ # keeping track of episode length
+
+ episode_lens = folder / 'episode_lens.npy'
+
+ self.episode_index = 0
+ self.timestep_index = 0
+
+ self.max_episodes = max_episodes
+ self.max_timesteps= max_timesteps
+
+ self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+ # create the memmap for individual data tracks
+
+ self.shapes = dict()
+ self.dtypes = dict()
+ self.memmaps = dict()
+ self.fieldnames = set(fields.keys())
+
+ for field_name, field_info in fields.items():
+
+ # some flexibility
+
+ field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+ dtype_str, shape = field_info
+ assert dtype_str in {'int', 'float', 'bool'}
+
+ dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+ # memmap file
+
+ filepath = folder / f'{field_name}.data.npy'
+ memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+ self.memmaps[field_name] = memmap
+ self.shapes[field_name] = shape
+ self.dtypes[field_name] = dtype
+
+ self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+ def __len__(self):
+ return (self.episode_lens > 0).sum().item()
+
+ def reset_(self):
+ self.episode_lens[:] = 0
+ self.episode_index = 0
+ self.timestep_index = 0
+
+ def advance_episode(self):
+ self.episode_index = (self.episode_index + 1) % self.max_episodes
+ self.timestep_index = 0
+
+ def flush(self):
+ self.episode_lens[self.episode_index] = self.timestep_index
+
+ for memmap in self.memmaps.values():
+ memmap.flush()
+
+ self.episode_lens.flush()
+
+ @contextmanager
+ def one_episode(self):
+
+ yield
+
+ self.flush()
+ self.advance_episode()
+
+ @beartype
+ def store_datapoint(
+ self,
+ episode_index: int,
+ timestep_index: int,
+ name: str,
+ datapoint: Tensor | ndarray
+ ):
+ assert 0 <= episode_index < self.max_episodes
+ assert 0 <= timestep_index < self.max_timesteps
+
+ if is_tensor(datapoint):
+ datapoint = datapoint.detach().cpu().numpy()
+
+ assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+ assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+ self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
+
+ def store(
+ self,
+ **data
+ ):
+ assert is_bearable(data, dict[str, Tensor | ndarray])
+
+ assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+ for name, datapoint in data.items():
+
+ self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+ self.timestep_index += 1
+
+ return self.memory_namedtuple(**data)
+
+ def dataset(
+ self,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ ) -> Dataset:
+ self.flush()
+
+ dataset = ReplayDataset(self.folder)
+
+ if not exists(episode_mapping):
+ return dataset
+
+ return RemappedReplayDataset(dataset, episode_mapping)
+
+ def dataloader(
+ self,
+ batch_size,
+ episode_mapping: Tensor | list[list[int]] | None = None,
+ **kwargs
+ ) -> DataLoader:
+ self.flush()
+
+ return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
  # transformer-xl with ppo

  class Attention(Module):
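A usage sketch for the replay buffer and datasets added above — the folder path, field names, and shapes here are made up for illustration. Fields are declared as a dtype string (`'int'`, `'float'`, `'bool'`), optionally paired with a per-timestep shape; episodes are written under the `one_episode` context manager, and `dataloader` pads variable-length episodes through `collate_var_time`:

    import torch
    from locoformer.locoformer import ReplayBuffer

    buffer = ReplayBuffer(
        './replay-data',
        max_episodes = 4,
        max_timesteps = 16,
        fields = dict(
            state  = ('float', (5,)),
            action = 'int',
            reward = 'float',
            done   = 'bool'
        )
    )

    with buffer.one_episode():
        for _ in range(10):
            buffer.store(
                state  = torch.randn(5),
                action = torch.tensor(1),
                reward = torch.tensor(0.5),
                done   = torch.tensor(False)
            )

    dataloader = buffer.dataloader(batch_size = 2)   # episodes are padded along time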
@@ -204,7 +521,6 @@ class Attention(Module):
  return_kv_cache = False,
  ):
  seq_len = tokens.shape[-2]
- assert seq_len <= self.window_size

  device = tokens.device

@@ -365,7 +681,21 @@ class Locoformer(Module):
  embedder: Module,
  unembedder: Module,
  transformer: dict | TransformerXL,
- value_network: Module | None = None
+ discount_factor = 0.999,
+ gae_lam = 0.95,
+ ppo_eps_clip = 0.2,
+ ppo_entropy_weight = 0.01,
+ ppo_value_clip = 0.4,
+ dim_value_input = None, # needs to be set for value network to be available
+ value_network: Module = nn.Identity(),
+ reward_range: tuple[float, float] | None = None,
+ reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+ num_reward_bins = 32,
+ hl_gauss_loss_kwargs = dict(),
+ value_loss_weight = 0.5,
+ calc_gae_kwargs: dict = dict(),
+ recurrent_kv_cache = True,
+ use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
  ):
  super().__init__()

@@ -377,11 +707,58 @@ class Locoformer(Module):
  self.embedder = embedder
  self.unembedder = unembedder

- self.value_network = value_network
-
  self.fixed_window_size = transformer.fixed_window_size
  self.window_size = transformer.window_size

+ # determine value network, using HL Gauss Layer
+
+ self.to_value_pred = None
+
+ if exists(dim_value_input):
+ assert exists(reward_range)
+
+ self.to_value_pred = nn.Sequential(
+ value_network,
+ LinearNoBias(dim_value_input, num_reward_bins)
+ )
+
+ reward_min, reward_max = reward_range
+
+ self.hl_gauss_loss = HLGaussLoss(
+ min_value = reward_min,
+ max_value = reward_max,
+ num_bins = num_reward_bins,
+ **hl_gauss_loss_kwargs
+ )
+
+ # ppo related
+
+ self.discount_factor = discount_factor
+ self.gae_lam = gae_lam
+ self.ppo_eps_clip = ppo_eps_clip
+ self.ppo_entropy_weight = ppo_entropy_weight
+ self.ppo_value_clip = ppo_value_clip
+ self.value_loss_weight = value_loss_weight
+
+ self.calc_gae_kwargs = calc_gae_kwargs
+
+ # maybe use spo
+
+ self.use_spo = use_spo
+
+ # maybe recurrent kv cache (todo: find and cite this paper from ages ago)
+
+ self.recurrent_kv_cache = recurrent_kv_cache
+
+ # reward shaping function
+
+ self.has_reward_shaping = exists(reward_shaping_fns)
+ self.reward_shaping_fns = reward_shaping_fns
+
+ # loss related
+
+ self.register_buffer('zero', tensor(0.), persistent = False)
+
  @property
  def device(self):
  return next(self.parameters()).device
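A wiring sketch for the new constructor arguments. The dimensions and the `transformer_xl` instance below are hypothetical (its construction is elided here); the point is that setting `dim_value_input` together with `reward_range` is what enables the HL-Gauss value head:

    from torch import nn
    from locoformer.locoformer import Locoformer

    model = Locoformer(
        embedder = nn.Linear(11, 128),      # state features -> tokens, dims assumed
        unembedder = nn.Linear(128, 4),     # tokens -> discrete action logits
        transformer = transformer_xl,       # an already-constructed TransformerXL, assumed
        dim_value_input = 128,              # turns on the value head
        reward_range = (-10., 10.),         # required once the value head exists
        num_reward_bins = 32
    )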
@@ -390,7 +767,165 @@ class Locoformer(Module):
  return self.unembedder.parameters()

  def critic_parameters(self):
- return self.value_network.parameters()
+ if not exists(self.to_value_pred):
+ return []
+
+ return self.to_value_pred.parameters()
+
+ def ppo(
+ self,
+ state,
+ action,
+ old_action_log_prob,
+ reward,
+ old_value,
+ mask,
+ episode_lens,
+ actor_optim: Optimizer | None = None,
+ critic_optim: Optimizer | None = None
+ ):
+ window_size = self.window_size
+ total_learnable_tokens = mask.sum().item()
+
+ seq_len = state.shape[1]
+ gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+ advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+ advantage = normalize(advantage)
+
+ windowed_tensors = [
+ t.split(window_size, dim = 1) for t in
+ (
+ state,
+ action,
+ old_action_log_prob,
+ reward,
+ old_value,
+ mask,
+ advantage,
+ returns
+ )
+ ]
+
+ mean_actor_loss = self.zero.clone()
+ mean_critic_loss = self.zero.clone()
+
+ # learn across windows
+
+ cache = None
+
+ for (
+ state,
+ action,
+ old_action_log_prob,
+ reward,
+ old_value,
+ mask,
+ advantage,
+ returns
+ ) in zip(*windowed_tensors):
+
+ (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+ entropy = calc_entropy(action_logits)
+
+ action = rearrange(action, 'b t -> b t 1')
+ log_prob = action_logits.gather(-1, action)
+ log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+ # update actor, classic clipped surrogate loss
+
+ eps_clip = self.ppo_eps_clip
+ ratio = (log_prob - old_action_log_prob).exp()
+
+ if self.use_spo:
+ actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+ else:
+ actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+ actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+ windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+ windowed_actor_loss.backward(retain_graph = True)
+
+ # update critic
+
+ value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
+
+ value_clip = self.ppo_value_clip
+ value = self.hl_gauss_loss(value_logits)
+
+ clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+ clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+ critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+ windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+ windowed_critic_loss.backward(retain_graph = True)
+
+ # accumulate
+
+ mean_actor_loss.add_(windowed_actor_loss)
+ mean_critic_loss.add_(windowed_critic_loss)
+
+ # optimizer update
+
+ if exists(actor_optim):
+ actor_optim.step()
+ actor_optim.zero_grad()
+
+ if exists(critic_optim):
+ critic_optim.step()
+ critic_optim.zero_grad()
+
+ # return losses for logging
+
+ return mean_actor_loss.detach(), mean_critic_loss.detach()
+
+ def state_to_rewards(
+ self,
+ state
+ ) -> Tensor:
+
+ assert self.has_reward_shaping
+
+ rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+ rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+ return stack(rewards)
+
+ def wrap_env_functions(self, env):
+
+ def transform_output(el):
+ if isinstance(el, ndarray):
+ return from_numpy(el)
+ elif isinstance(el, (int, bool, float)):
+ return tensor(el)
+ else:
+ return el
+
+ def wrapped_reset(*args, **kwargs):
+ env_reset_out = env.reset(*args, **kwargs)
+
+ return tree_map(transform_output, env_reset_out)
+
+ def wrapped_step(action, *args, **kwargs):
+
+ if is_tensor(action):
+ action = action.item()
+
+ env_step_out = env.step(action, *args, **kwargs)
+
+ env_step_out_torch = tree_map(transform_output, env_step_out)
+
+ if not self.has_reward_shaping:
+ return env_step_out_torch
+
+ shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+ return env_step_out_torch, shaped_rewards
+
+ return wrapped_reset, wrapped_step

  def get_stateful_forward(
  self,
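A sketch of `wrap_env_functions`, assuming a gymnasium-style `env` and a constructed Locoformer `model`. The wrappers convert numpy and scalar outputs to torch tensors, unwrap tensor actions with `.item()` before calling `env.step`, and, when `reward_shaping_fns` is set, also return the shaped rewards from the step function:

    import torch

    reset_fn, step_fn = model.wrap_env_functions(env)

    state, info = reset_fn()               # numpy outputs become torch tensors
    step_out = step_fn(torch.tensor(0))    # with reward shaping on, this is (env outputs, shaped rewards)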
@@ -398,6 +933,7 @@ class Locoformer(Module):
  inference_mode = False,
  has_batch_dim = False,
  has_time_dim = False,
+ state_time_dim = 1,
  **kwargs
  ):
  window_size = self.window_size
@@ -413,23 +949,16 @@ class Locoformer(Module):
  state = rearrange(state, '... -> 1 ...')

  if not has_time_dim:
- state = rearrange(state, '... d -> ... 1 d')
+ state = state.unsqueeze(state_time_dim)

  # forwards

  out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})

- # handle cache
-
- cache_len = cache.shape[-2]
-
- if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
- cache = cache[..., -window_size:, :]
-
  # maybe remove batch or time

  if not has_time_dim:
- out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
+ out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

  if not has_batch_dim:
  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -458,14 +987,35 @@ class Locoformer(Module):
  def forward(
  self,
  state: Tensor,
- cache: Tensor | None = None,
+ cache: Cache | None = None,
  detach_cache = False,
- return_values = False
+ return_values = False,
+ return_raw_value_logits = False
  ):

+ state = state.to(self.device)
+
  tokens = self.embedder(state)

- embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
+ # time
+
+ time = tokens.shape[-2]
+
+ # destruct the cache for the current timestep and the cache
+
+ prev_kv_cache = None
+ timestep_start = 0
+
+ if exists(cache):
+ timestep_start, prev_kv_cache = cache
+
+ # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+ assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+ # attention
+
+ embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)

  # unembed to actions - in language models this would be the next state

@@ -476,21 +1026,34 @@ class Locoformer(Module):
  # maybe detach cache

  if detach_cache:
- kv_cache = detach_all(kv_cache)
+ kv_cache = kv_cache.detach()

  # handle returning of values

  if return_values:
- assert exists(self.value_network)
+ assert exists(self.to_value_pred)

- values = self.value_network(embed)
+ values = self.to_value_pred(embed)

- if values.ndim == 3:
- assert values.shape[-1] == 1
- values = rearrange(values, '... 1 -> ...')
+ if not return_raw_value_logits:
+ values = self.hl_gauss_loss(values) # converts the value logits to scalar values

  out = (out, values)

  # output and cache

- return out, kv_cache
+ next_timestep = time + timestep_start
+
+ # handle curtailing kv cache at the right intervals
+
+ window_size = self.window_size
+
+ if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+ kv_cache = kv_cache[..., -window_size:, :]
+
+ # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+ if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+ kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+ return out, (next_timestep, kv_cache)
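The cache is now a `(current timestep, kv cache)` tuple; the kv cache is trimmed back to the last window at every other window boundary (or always, when `fixed_window_size` is set) and, with `recurrent_kv_cache = True`, rolled one layer down at each window boundary. A hedged sketch of stepping one timestep at a time (a constructed `model` with a value head and an 11-dim state is assumed; the package's own `get_stateful_forward` wraps this same pattern):

    import torch

    cache = None

    for _ in range(64):
        state = torch.randn(1, 1, 11)    # (batch, time = 1, state dim)
        (action_logits, values), cache = model(state, cache = cache, return_values = True)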
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.7
+ Version: 0.0.29
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,8 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: assoc-scan
+ Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
+ Requires-Dist: hl-gauss-pytorch>=0.2.0
  Requires-Dist: rotary-embedding-torch
  Requires-Dist: torch>=2.4
  Requires-Dist: x-mlps-pytorch
@@ -53,7 +55,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=Tr_1btuoTZ0huXeDcAeuHxTPaVeCUEGc5iLvMYGDLck,29982
+ locoformer-0.0.29.dist-info/METADATA,sha256=5Fi3EOsgpBvpzAFVZQyrlink-HcHE8EgFl10Y5l8mqM,3256
+ locoformer-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ locoformer-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.29.dist-info/RECORD,,
@@ -1,6 +0,0 @@
- locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
- locoformer/locoformer.py,sha256=lJQs0CKr9iztF8tie1FRUVEItCt-IZbIILQqKcgK2sI,13142
- locoformer-0.0.7.dist-info/METADATA,sha256=PZ_phKV3t4Bha0GnUB5HPmE9w8A5fvNevsuN532Ls3s,3193
- locoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- locoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- locoformer-0.0.7.dist-info/RECORD,,