locoformer-0.0.43-py3-none-any.whl → locoformer-0.1.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in one of the supported public registries. It is provided for informational purposes only.
locoformer/locoformer.py CHANGED
@@ -1,11 +1,14 @@
1
1
  from __future__ import annotations
2
+ import math
2
3
  from typing import Callable
3
4
  from types import SimpleNamespace
4
5
  from functools import partial, wraps
5
6
 
6
7
  from pathlib import Path
7
8
  from contextlib import contextmanager
8
- from collections import namedtuple
9
+ from collections import namedtuple, deque
10
+
11
+ from glom import glom
9
12
 
10
13
  from inspect import signature
11
14
 
@@ -17,16 +20,17 @@ from beartype import beartype
17
20
  from beartype.door import is_bearable
18
21
 
19
22
  import torch
20
- from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
23
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy, nested
21
24
  import torch.nn.functional as F
22
25
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
23
26
  from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
24
27
  from torch.utils.data import Dataset, DataLoader
28
+ from torch.distributions import Normal
25
29
  from torch.optim import Optimizer
26
30
 
27
31
  import einx
28
- from einops import rearrange, einsum
29
- from einops.layers.torch import Rearrange
32
+ from einops import rearrange, repeat, einsum, reduce, pack
33
+ from einops.layers.torch import Rearrange, Reduce
30
34
 
31
35
  from rotary_embedding_torch import RotaryEmbedding
32
36
 
@@ -38,11 +42,24 @@ from x_mlps_pytorch import MLP
38
42
 
39
43
  from x_evolution import EvoStrategy
40
44
 
45
+ from discrete_continuous_embed_readout import EmbedAndReadout, Embed, Readout
46
+
47
+ from hyper_connections import mc_get_init_and_expand_reduce_stream_functions
48
+
49
+ from memmap_replay_buffer import ReplayBuffer, ReplayDataset
50
+
41
51
  # constants
42
52
 
43
53
  LinearNoBias = partial(Linear, bias = False)
44
54
 
45
- Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
55
+ TransformerMemory = namedtuple('TransformerMemory', (
56
+ 'total_tokens',
57
+ 'kv_cache',
58
+ 'gru_cache',
59
+ 'mem_mlp_cache',
60
+ 'mem_mlp_hidden_states',
61
+ 'memory_segments'
62
+ ))
46
63
 
47
64
  # helper functions
48
65
 
@@ -52,6 +69,18 @@ def exists(v):
52
69
  def default(v, d):
53
70
  return v if exists(v) else d
54
71
 
72
+ def always(val):
73
+ def inner(*args, **kwargs):
74
+ return val
75
+
76
+ return inner
77
+
78
+ def identity(t, *args, **kwargs):
79
+ return t
80
+
81
+ def pick(data, keys):
82
+ return tuple(data[k] for k in keys)
83
+
55
84
  def first(arr):
56
85
  return arr[0]
57
86
 
@@ -100,6 +129,11 @@ def is_empty(t):
100
129
  def tree_map_tensor(x, fn):
101
130
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
102
131
 
132
+ def lens_to_mask(lens, max_len):
133
+ device = lens.device
134
+ seq = arange(max_len, device = device)
135
+ return einx.less('j, i -> i j', seq, lens)
136
+
103
137
  def pad_at_dim(
104
138
  t,
105
139
  pad: tuple[int, int],
@@ -113,8 +147,20 @@ def pad_at_dim(
113
147
  zeros = ((0, 0) * dims_from_right)
114
148
  return F.pad(t, (*zeros, *pad), value = value)
115
149
 
116
- def normalize(t, eps = 1e-5):
117
- return (t - t.mean()) / t.std().clamp_min(eps)
150
+ def safe_cat(t, next_t, dim = -1):
151
+ if not exists(t):
152
+ return next_t
153
+
154
+ return cat((t, next_t), dim = dim)
155
+
156
+ def normalize(t, mask = None, eps = 1e-5):
157
+ if exists(mask):
158
+ assert mask.any()
159
+
160
+ t_for_stats = t[mask] if exists(mask) else t
161
+ var, mean = torch.var_mean(t_for_stats)
162
+
163
+ return (t - mean) / var.sqrt().clamp_min(eps)
118
164
 
119
165
  def tensor_to_dict(
120
166
  t: Tensor,
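
The masked `normalize` added above computes its statistics only over positions selected by the mask, then applies them to the whole tensor. A minimal sketch (editorial illustration, not part of the package):

    import torch

    t = torch.tensor([1., 2., 3., 100.])
    mask = torch.tensor([True, True, True, False])   # e.g. the last position is padding

    # statistics are taken over t[mask] only, so the padded 100. does not skew them
    var, mean = torch.var_mean(t[mask])
    normalized = (t - mean) / var.sqrt().clamp_min(1e-5)
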
@@ -135,9 +181,75 @@ def tensor_to_dict(
135
181
 
136
182
  return SimpleNamespace(**tensor_dict)
137
183
 
138
- def calc_entropy(logits):
139
- prob = logits.softmax(dim = -1)
140
- return -(prob * log(prob)).sum(dim = -1)
184
+ # dataset related
185
+
186
+ class RemappedReplayDataset(Dataset):
187
+ def __init__(
188
+ self,
189
+ dataset: ReplayDataset,
190
+ episode_mapping: Tensor | list[list[int]],
191
+ shuffle_episodes = False,
192
+ num_trials_select = None
193
+ ):
194
+ assert len(dataset) > 0
195
+ self.dataset = dataset
196
+
197
+ if is_tensor(episode_mapping):
198
+ assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
199
+ episode_mapping = episode_mapping.tolist()
200
+
201
+ self.episode_mapping = episode_mapping
202
+ self.shuffle_episodes = shuffle_episodes
203
+
204
+ assert not (exists(num_trials_select) and num_trials_select <= 0)
205
+ self.sub_select_trials = exists(num_trials_select)
206
+ self.num_trials_select = num_trials_select
207
+
208
+ def __len__(self):
209
+ return len(self.episode_mapping)
210
+
211
+ def __getitem__(self, idx):
212
+
213
+ episode_indices = self.episode_mapping[idx]
214
+
215
+ episode_indices = tensor(episode_indices)
216
+ episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
217
+
218
+ assert not is_empty(episode_indices)
219
+
220
+ # shuffle the episode indices if either `shuffle_episodes` is turned on, or `num_trials_select` is passed in (for sub-selecting episodes from a set)
221
+
222
+ if (
223
+ episode_indices.numel() > 1 and
224
+ (self.shuffle_episodes or self.sub_select_trials)
225
+ ):
226
+ num_episodes = len(episode_indices)
227
+ episode_indices = episode_indices[torch.randperm(num_episodes)]
228
+
229
+ # crop out the episodes
230
+
231
+ if self.sub_select_trials:
232
+ episode_indices = episode_indices[:self.num_trials_select]
233
+
234
+ # now select out the episode data and merge along time
235
+
236
+ episode_data = [self.dataset[i] for i in episode_indices.tolist()]
237
+
238
+ episode_lens = stack([data.pop('_lens') for data in episode_data])
239
+
240
+ keys = first(episode_data).keys()
241
+
242
+ values = [list(data.values()) for data in episode_data]
243
+
244
+ values = [cat(field_values) for field_values in zip(*values)] # concat across time
245
+
246
+ multi_episode_data = dict(zip(keys, values))
247
+
248
+ multi_episode_data['_lens'] = episode_lens.sum()
249
+
250
+ multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
251
+
252
+ return multi_episode_data
141
253
 
142
254
  # reward functions - A.2
143
255
 
@@ -300,305 +412,6 @@ def create_sliding_mask(
300
412
  create_kwargs = dict(device = device) if exists(device) else dict()
301
413
  return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
302
414
 
303
- # data
304
-
305
- def collate_var_time(data):
306
-
307
- datum = first(data)
308
- keys = datum.keys()
309
-
310
- all_tensors = zip(*[datum.values() for datum in data])
311
-
312
- collated_values = []
313
-
314
- for key, tensors in zip(keys, all_tensors):
315
-
316
- # the episode lens have zero dimension - think of a cleaner way to handle this later
317
-
318
- if key != '_lens':
319
-
320
- times = [t.shape[0] for t in tensors]
321
- max_time = max(times)
322
- tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
323
-
324
- collated_values.append(stack(tensors))
325
-
326
- return dict(zip(keys, collated_values))
327
-
328
- class ReplayDataset(Dataset):
329
- def __init__(
330
- self,
331
- folder: str | Path,
332
- fields: tuple[str, ...] | None = None
333
- ):
334
- if isinstance(folder, str):
335
- folder = Path(folder)
336
-
337
- episode_lens = folder / 'episode_lens.npy'
338
- self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
339
-
340
- # get indices of non-zero lengthed episodes
341
-
342
- nonzero_episodes = self.episode_lens > 0
343
- self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
344
-
345
- # get all data files
346
-
347
- filepaths = [*folder.glob('*.data.npy')]
348
- assert len(filepaths) > 0
349
-
350
- fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
351
-
352
- fieldnames_from_files = set(fieldname_to_filepath.keys())
353
-
354
- fields = default(fields, fieldnames_from_files)
355
-
356
- self.memmaps = dict()
357
-
358
- for field in fields:
359
- assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
360
-
361
- path = fieldname_to_filepath[field]
362
-
363
- self.memmaps[field] = open_memmap(str(path), mode = 'r')
364
-
365
- def __len__(self):
366
- return len(self.indices)
367
-
368
- def __getitem__(self, idx):
369
- episode_index = self.indices[idx]
370
-
371
- episode_len = self.episode_lens[episode_index]
372
-
373
- data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
374
-
375
- data['_lens'] = tensor(episode_len)
376
-
377
- return data
378
-
379
- class RemappedReplayDataset(Dataset):
380
- def __init__(
381
- self,
382
- dataset: ReplayDataset,
383
- episode_mapping: Tensor | list[list[int]],
384
- shuffle_episodes = False,
385
- num_trials_select = None
386
- ):
387
- assert len(dataset) > 0
388
- self.dataset = dataset
389
-
390
- if is_tensor(episode_mapping):
391
- assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
392
- episode_mapping = episode_mapping.tolist()
393
-
394
- self.episode_mapping = episode_mapping
395
- self.shuffle_episodes = shuffle_episodes
396
-
397
- assert not (exists(num_trials_select) and num_trials_select >= 1)
398
- self.sub_select_trials = exists(num_trials_select)
399
- self.num_trials_select = num_trials_select
400
-
401
- def __len__(self):
402
- return len(self.episode_mapping)
403
-
404
- def __getitem__(self, idx):
405
-
406
- episode_indices = self.episode_mapping[idx]
407
-
408
- episode_indices = tensor(episode_indices)
409
- episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
410
-
411
- assert not is_empty(episode_indices)
412
-
413
- # shuffle the episode indices if either shuffle episodes is turned on, or `num_trial_select` passed in (for sub selecting episodes from a set)
414
-
415
- if (
416
- episode_indices.numel() > 1 and
417
- (self.shuffle_episodes or self.sub_select_trials)
418
- ):
419
- num_episodes = len(episode_indices)
420
- episode_indices = episode_indices[torch.randperm(num_episodes)]
421
-
422
- # crop out the episodes
423
-
424
- if self.sub_select_trials:
425
- episode_indices = episode_indices[:self.num_trials_select]
426
-
427
- # now select out the episode data and merge along time
428
-
429
- episode_data = [self.dataset[i] for i in episode_indices.tolist()]
430
-
431
- episode_lens = stack([data.pop('_lens') for data in episode_data])
432
-
433
- keys = first(episode_data).keys()
434
-
435
- values = [list(data.values()) for data in episode_data]
436
-
437
- values = [cat(field_values) for field_values in zip(*values)] # concat across time
438
-
439
- multi_episode_data = dict(zip(keys, values))
440
-
441
- multi_episode_data['_lens'] = episode_lens.sum()
442
-
443
- multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
444
-
445
- return multi_episode_data
446
-
447
- class ReplayBuffer:
448
-
449
- @beartype
450
- def __init__(
451
- self,
452
- folder: str | Path,
453
- max_episodes: int,
454
- max_timesteps: int,
455
- fields: dict[
456
- str,
457
- str | tuple[str, int | tuple[int, ...]]
458
- ]
459
- ):
460
-
461
- # folder for data
462
-
463
- if not isinstance(folder, Path):
464
- folder = Path(folder)
465
- folder.mkdir(exist_ok = True)
466
-
467
- self.folder = folder
468
- assert folder.is_dir()
469
-
470
- # keeping track of episode length
471
-
472
- episode_lens = folder / 'episode_lens.npy'
473
-
474
- self.episode_index = 0
475
- self.timestep_index = 0
476
-
477
- self.max_episodes = max_episodes
478
- self.max_timesteps= max_timesteps
479
-
480
- self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
481
-
482
- # create the memmap for individual data tracks
483
-
484
- self.shapes = dict()
485
- self.dtypes = dict()
486
- self.memmaps = dict()
487
- self.fieldnames = set(fields.keys())
488
-
489
- for field_name, field_info in fields.items():
490
-
491
- # some flexibility
492
-
493
- field_info = (field_info, ()) if isinstance(field_info, str) else field_info
494
-
495
- dtype_str, shape = field_info
496
- assert dtype_str in {'int', 'float', 'bool'}
497
-
498
- dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
499
-
500
- # memmap file
501
-
502
- filepath = folder / f'{field_name}.data.npy'
503
-
504
- if isinstance(shape, int):
505
- shape = (shape,)
506
-
507
- memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
508
-
509
- self.memmaps[field_name] = memmap
510
- self.shapes[field_name] = shape
511
- self.dtypes[field_name] = dtype
512
-
513
- self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
514
-
515
- def __len__(self):
516
- return (self.episode_lens > 0).sum().item()
517
-
518
- def reset_(self):
519
- self.episode_lens[:] = 0
520
- self.episode_index = 0
521
- self.timestep_index = 0
522
-
523
- def advance_episode(self):
524
- self.episode_index = (self.episode_index + 1) % self.max_episodes
525
- self.timestep_index = 0
526
-
527
- def flush(self):
528
- self.episode_lens[self.episode_index] = self.timestep_index
529
-
530
- for memmap in self.memmaps.values():
531
- memmap.flush()
532
-
533
- self.episode_lens.flush()
534
-
535
- @contextmanager
536
- def one_episode(self):
537
-
538
- yield
539
-
540
- self.flush()
541
- self.advance_episode()
542
-
543
- @beartype
544
- def store_datapoint(
545
- self,
546
- episode_index: int,
547
- timestep_index: int,
548
- name: str,
549
- datapoint: Tensor | ndarray
550
- ):
551
- assert 0 <= episode_index < self.max_episodes
552
- assert 0 <= timestep_index < self.max_timesteps
553
-
554
- if is_tensor(datapoint):
555
- datapoint = datapoint.detach().cpu().numpy()
556
-
557
- assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
558
-
559
- assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
560
-
561
- self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
562
-
563
- def store(
564
- self,
565
- **data
566
- ):
567
- assert is_bearable(data, dict[str, Tensor | ndarray])
568
-
569
- assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
570
-
571
- for name, datapoint in data.items():
572
-
573
- self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
574
-
575
- self.timestep_index += 1
576
-
577
- return self.memory_namedtuple(**data)
578
-
579
- def dataset(
580
- self,
581
- episode_mapping: Tensor | list[list[int]] | None = None,
582
- ) -> Dataset:
583
- self.flush()
584
-
585
- dataset = ReplayDataset(self.folder)
586
-
587
- if not exists(episode_mapping):
588
- return dataset
589
-
590
- return RemappedReplayDataset(dataset, episode_mapping)
591
-
592
- def dataloader(
593
- self,
594
- batch_size,
595
- episode_mapping: Tensor | list[list[int]] | None = None,
596
- **kwargs
597
- ) -> DataLoader:
598
- self.flush()
599
-
600
- return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
601
-
602
415
  # normalization + conditioning (needed for the commands to the robot)
603
416
 
604
417
  class MaybeAdaRMSNormWrapper(Module):
@@ -627,34 +440,52 @@ class MaybeAdaRMSNormWrapper(Module):
627
440
  def forward(
628
441
  self,
629
442
  x,
443
+ *args,
630
444
  cond = None,
445
+ cond_mask = None,
631
446
  **kwargs
632
447
  ):
633
448
 
634
449
  need_cond = self.accept_condition
450
+ has_input_cond = need_cond and exists(cond)
635
451
 
636
- assert xnor(exists(cond), need_cond)
452
+ if exists(cond):
453
+ assert self.accept_condition
637
454
 
638
455
  prenormed = self.norm(x)
639
456
 
640
- if need_cond:
457
+ if has_input_cond:
641
458
  if cond.ndim == 2:
642
459
  cond = rearrange(cond, 'b d -> b 1 d')
643
460
 
644
- scale_in = self.to_gamma(cond)
645
- prenormed = prenormed * (scale_in + 1.)
461
+ cond_scale = self.to_gamma(cond)
462
+
463
+ conditioned = prenormed * cond_scale
464
+
465
+ # handle a condition mask
646
466
 
647
- all_fn_out = self.fn(prenormed, **kwargs)
467
+ if exists(cond_mask):
468
+ prenormed = einx.where('b n, b n d, b n d', cond_mask, conditioned, prenormed)
469
+ else:
470
+ prenormed = conditioned
471
+
472
+ # the main block, either attention or feedforward or whatever
648
473
 
649
- if not need_cond:
474
+ all_fn_out = self.fn(prenormed, *args, **kwargs)
475
+
476
+ if not has_input_cond:
650
477
  return all_fn_out
651
478
 
652
479
  # function may return multiple args
653
480
 
654
481
  (out, *rest), tree_spec = tree_flatten(all_fn_out)
655
482
 
656
- if need_cond:
657
- scale_out = self.to_ada_norm_zero(cond).sigmoid()
483
+ scale_out = self.to_ada_norm_zero(cond).sigmoid()
484
+
485
+ if exists(cond_mask):
486
+ is_cond = rearrange(cond_mask, '... -> ... 1')
487
+ out = torch.where(is_cond, out * scale_out, out)
488
+ else:
658
489
  out = out * scale_out
659
490
 
660
491
  # restore
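
To make the new `cond_mask` handling above concrete: the mask is a boolean (batch, seq) tensor selecting which positions receive the AdaRMSNorm conditioning, while unmasked positions fall through on the plain pre-norm path. A sketch with made-up sizes:

    block = MaybeAdaRMSNormWrapper(FeedForward(dim = 512, expansion_factor = 4), dim = 512, dim_cond = 1024)

    x = torch.randn(2, 16, 512)                       # (batch, seq, dim)
    cond = torch.randn(2, 1024)                       # broadcast across the sequence
    cond_mask = torch.zeros(2, 16, dtype = torch.bool)
    cond_mask[:, :8] = True                           # only the first 8 positions are conditioned

    out = block(x, cond = cond, cond_mask = cond_mask)
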
@@ -673,7 +504,8 @@ class Attention(Module):
673
504
  dim_head = 64,
674
505
  heads = 8,
675
506
  fixed_window_size = False,
676
- accept_value_residual = False
507
+ accept_value_residual = False,
508
+ max_mem_segments = 1
677
509
  ):
678
510
  super().__init__()
679
511
  self.scale = dim_head ** -0.5
@@ -709,12 +541,16 @@ class Attention(Module):
709
541
 
710
542
  self.fixed_window_size = fixed_window_size
711
543
  self.window_size = window_size
544
+ self.max_mem_segments = max_mem_segments
545
+
546
+ self.register_buffer('causal_mask', None, persistent = False)
712
547
 
713
548
  def forward(
714
549
  self,
715
550
  tokens,
716
551
  value_residual = None,
717
552
  kv_cache = None,
553
+ past_segments = None,
718
554
  return_kv_cache = False,
719
555
  ):
720
556
  seq_len = tokens.shape[-2]
@@ -739,6 +575,11 @@ class Attention(Module):
739
575
  k = cat((ck, k), dim = -2)
740
576
  v = cat((cv, v), dim = -2)
741
577
 
578
+ if exists(past_segments):
579
+ pk, pv = past_segments
580
+ k = cat((pk, k), dim = -2)
581
+ v = cat((pv, v), dim = -2)
582
+
742
583
  if return_kv_cache:
743
584
  next_kv_cache = stack((k, v))
744
585
 
@@ -752,9 +593,12 @@ class Attention(Module):
752
593
  i_seq = arange(i, device = device)
753
594
  j_seq = arange(j, device = device) - (j - i)
754
595
  dist = einx.subtract('i, j -> i j', i_seq, j_seq)
755
- causal_mask = (dist < 0) | (dist > self.window_size)
596
+ causal_mask = (dist < 0) | (dist > (self.max_mem_segments * self.window_size))
756
597
  else:
757
- causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
598
+ if not exists(self.causal_mask) or self.causal_mask.shape != (i, j):
599
+ self.causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
600
+
601
+ causal_mask = self.causal_mask
758
602
 
759
603
  sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
760
604
 
@@ -801,6 +645,7 @@ class FeedForward(Module):
801
645
  return self.proj_out(x)
802
646
 
803
647
  class TransformerXL(Module):
648
+ @beartype
804
649
  def __init__(
805
650
  self,
806
651
  dim,
@@ -812,8 +657,31 @@ class TransformerXL(Module):
812
657
  dim_cond = None,
813
658
  final_norm = True,
814
659
  fixed_window_size = False,
660
+ gru_layers = False,
661
+ long_term_mem_layers: tuple[int, ...] = (),
662
+ mem_kwargs: dict = dict(),
663
+ num_residual_streams = 1,
664
+ max_mem_segments = 1
815
665
  ):
816
666
  super().__init__()
667
+ self.dim = dim
668
+
669
+ # memory
670
+
671
+ long_term_mem_layers = set(long_term_mem_layers)
672
+
673
+ assert all([1 <= l <= depth for l in long_term_mem_layers])
674
+
675
+ self.long_term_mem_layers = long_term_mem_layers
676
+ self.num_mem_mlps = len(long_term_mem_layers)
677
+ self.has_mem = self.num_mem_mlps > 0
678
+ self.max_mem_segments = max_mem_segments
679
+
680
+ # hyper connections
681
+
682
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mc_get_init_and_expand_reduce_stream_functions(num_residual_streams)
683
+
684
+ # condition
817
685
 
818
686
  condition = exists(dim_cond)
819
687
 
@@ -821,22 +689,48 @@ class TransformerXL(Module):
821
689
 
822
690
  norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
823
691
 
692
+ # layers
693
+
824
694
  layers = ModuleList([])
825
695
 
826
696
  for i in range(depth):
827
- is_first = i == 0
697
+ layer = i + 1
698
+ is_first = layer == 1
699
+ has_mem = layer in long_term_mem_layers
700
+
701
+ gru = norm_fn(nn.GRU(dim, dim, batch_first = True)) if gru_layers else None
828
702
 
829
- attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first))
703
+ mem = MemoryMLP(dim, **mem_kwargs) if has_mem else None
704
+
705
+ attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first, max_mem_segments = max_mem_segments))
830
706
 
831
707
  ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
832
708
 
709
+ # maybe hyper connections
710
+
711
+ mem_store_hc = init_hyper_conn(dim = dim, add_branch_out_to_residual = False)
712
+
713
+ if num_residual_streams > 1:
714
+
715
+ attn = init_hyper_conn(dim = dim, branch = attn)
716
+
717
+ ff = init_hyper_conn(dim = dim, branch = ff)
718
+
719
+ if gru_layers:
720
+ gru = init_hyper_conn(dim = dim, branch = gru)
721
+
722
+ if has_mem:
723
+ mem = init_hyper_conn(dim = dim, branch = mem, forward_method_names = ('store',))
724
+
833
725
  layers.append(ModuleList([
834
- attn, ff
726
+ gru, mem, mem_store_hc, attn, ff
835
727
  ]))
836
728
 
837
729
  self.layers = layers
838
730
  self.norm = RMSNorm(dim) if final_norm else Identity()
839
731
 
732
+ self.gru_layers = gru_layers
733
+
840
734
  # fixed window size
841
735
 
842
736
  self.fixed_window_size = fixed_window_size
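
A hedged construction sketch for the new `TransformerXL` options introduced above (all values are illustrative; the remaining constructor arguments are unchanged from 0.0.43): `long_term_mem_layers` picks the 1-indexed layers that receive a `MemoryMLP`, `gru_layers` inserts a GRU block per layer, and `num_residual_streams > 1` enables hyper connections.

    xl = TransformerXL(
        dim = 512,
        depth = 6,
        window_size = 64,
        fixed_window_size = True,
        gru_layers = True,
        long_term_mem_layers = (3, 6),   # must satisfy 1 <= layer <= depth
        num_residual_streams = 4,
        max_mem_segments = 2
    )
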
@@ -845,92 +739,439 @@ class TransformerXL(Module):
845
739
  def forward(
846
740
  self,
847
741
  x,
848
- cache = None,
742
+ cache: TransformerMemory | None = None,
849
743
  return_kv_cache = False,
850
- condition: Tensor | None = None
744
+ condition: Tensor | None = None,
745
+ cond_mask: Tensor | None = None
851
746
  ):
747
+ curr_token_seq_len = x.shape[-2]
852
748
 
853
749
  # cache and residuals
854
750
 
855
- cache = default(cache, (None,) * len(self.layers))
751
+ num_layers = len(self.layers)
752
+
753
+ # extract variables from cache
754
+
755
+ is_first_window = True
756
+ total_tokens = 0
757
+ kv_cache = gru_cache = mem_mlp_cache = mem_mlp_hidden_states = memory_segments = None
758
+
759
+ if exists(cache):
760
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
761
+ is_first_window = total_tokens < self.window_size
762
+ else:
763
+ memory_segments = deque(maxlen = self.max_mem_segments)
764
+
765
+ # handle memory segments
766
+
767
+ past_segments = None
768
+ if len(memory_segments) > 0:
769
+ past_segments = stack(list(memory_segments), dim = 0)
770
+ past_segments = rearrange(past_segments, 'l depth kv b h n d -> depth kv b h (l n) d')
771
+
772
+ kv_cache = default(kv_cache, (None,) * num_layers)
773
+ gru_cache = default(gru_cache, (None,) * num_layers)
774
+ mem_mlp_cache = default(mem_mlp_cache, (None,) * num_layers)
775
+ mem_mlp_hidden_states = default(mem_mlp_hidden_states, (None,) * num_layers)
776
+
777
+ # prepare next cache
856
778
 
857
779
  next_kv_caches = []
780
+ next_gru_hiddens = [] if self.gru_layers else None
781
+ next_mem_mlp_cache = [] if self.has_mem else None
782
+ next_mem_mlp_hidden_states = [] if self.has_mem else None
783
+ next_total_tokens = total_tokens + curr_token_seq_len
784
+
785
+ is_window_boundary = divisible_by(next_total_tokens, self.window_size)
786
+
858
787
  value_residual = None
859
788
 
860
789
  # handle condition
861
790
 
862
791
  cond_tokens = None
792
+
863
793
  if exists(condition):
864
794
  assert exists(self.to_cond_tokens)
865
795
  cond_tokens = self.to_cond_tokens(condition)
866
796
 
797
+ cond_kwargs = dict(cond = cond_tokens, cond_mask = cond_mask)
798
+
799
+ # hc expand
800
+
801
+ x = self.expand_streams(x)
802
+
867
803
  # layers
868
804
 
869
- for (attn, ff), kv_cache in zip(self.layers, cache):
805
+ for layer_index, ((maybe_gru, maybe_mem, maybe_mem_store_hc, attn, ff), layer_gru_cache, layer_mem_mlp, layer_kv_cache, layer_hidden_states) in enumerate(zip(self.layers, gru_cache, mem_mlp_cache, kv_cache, mem_mlp_hidden_states)):
870
806
 
871
- attn_out, (next_kv_cache, values) = attn(x, cond = cond_tokens, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
807
+ # handle maybe rnn
872
808
 
873
- x = attn_out + x
874
- x = ff(x, cond = cond_tokens) + x
809
+ if exists(maybe_gru):
810
+ x, gru_hiddens = maybe_gru(x, layer_gru_cache, **cond_kwargs)
811
+
812
+ next_gru_hiddens.append(gru_hiddens)
813
+
814
+ # maybe handle retrieving
815
+
816
+ is_mem_layer = exists(maybe_mem)
817
+
818
+ if (
819
+ not is_first_window and
820
+ is_mem_layer
821
+ ):
822
+ x = maybe_mem(x, layer_mem_mlp)
823
+
824
+ # attention
825
+
826
+ layer_past_segments = None
827
+ if exists(past_segments):
828
+ layer_past_segments = past_segments[layer_index]
829
+
830
+ x, (next_kv_cache, values) = attn(x, **cond_kwargs, value_residual = value_residual, kv_cache = layer_kv_cache, past_segments = layer_past_segments, return_kv_cache = True)
831
+
832
+ # handle storing of memory
833
+
834
+ if self.has_mem:
835
+ next_mem_mlp = layer_mem_mlp
836
+ next_layer_hidden_states = layer_hidden_states
837
+
838
+ if is_mem_layer:
839
+ # accumulate hidden states
840
+ next_layer_hidden_states = safe_cat(layer_hidden_states, x, dim = -2)
841
+
842
+ if is_window_boundary:
843
+ mem_store_input, _ = maybe_mem_store_hc(next_layer_hidden_states)
844
+
845
+ next_mem_mlp = maybe_mem.store(mem_store_input, layer_mem_mlp)
846
+ next_layer_hidden_states = None
847
+
848
+ next_mem_mlp_cache.append(next_mem_mlp)
849
+ next_mem_mlp_hidden_states.append(next_layer_hidden_states)
850
+
851
+ # feedforward
852
+
853
+ x = ff(x, **cond_kwargs)
875
854
 
876
855
  next_kv_caches.append(next_kv_cache)
877
856
  value_residual = default(value_residual, values)
878
857
 
879
- embed = self.norm(x)
858
+ # hc reduce
880
859
 
881
- if not return_kv_cache:
882
- return embed
860
+ x = self.reduce_streams(x)
861
+
862
+ # norm
863
+
864
+ embed = self.norm(x)
883
865
 
884
866
  next_kv_cache = stack(next_kv_caches)
885
867
 
886
- next_kv_cache = next_kv_cache[..., -self.window_size:, :]
868
+ if exists(next_gru_hiddens):
869
+ next_gru_hiddens = stack(next_gru_hiddens)
870
+
871
+ next_cache = TransformerMemory(next_total_tokens, next_kv_cache, next_gru_hiddens, next_mem_mlp_cache, next_mem_mlp_hidden_states, memory_segments)
872
+
873
+ return embed, next_cache
874
+
875
+ # simple 2 layer memory mlp
876
+ # following ttt/titans
877
+
878
+ from torch.func import functional_call, grad, vmap
879
+
880
+ class MemoryMLP(Module):
881
+ def __init__(
882
+ self,
883
+ dim,
884
+ expansion_factor = 4.
885
+ ):
886
+ super().__init__()
887
+
888
+ dim_hidden = int(dim * expansion_factor)
889
+
890
+ self.norm = nn.RMSNorm(dim)
891
+
892
+ # queries, keys, values
893
+
894
+ self.to_queries = Linear(dim, dim, bias = False)
895
+
896
+ self.to_key_values = nn.Sequential(
897
+ Linear(dim, dim * 2, bias = False),
898
+ nn.SiLU()
899
+ )
900
+
901
+ # memory mlp
902
+
903
+ self.mlp = MLP(dim, dim_hidden, dim, activation = nn.SiLU())
904
+
905
+ # initial params
906
+
907
+ self.init_mlp_params = dict(self.mlp.named_parameters())
908
+
909
+ # grad for storing
910
+
911
+ def retrieve_fn(params, queries: Tensor):
912
+ return functional_call(self.mlp, params, queries)
913
+
914
+ def loss_fn(params, inputs: tuple[Tensor, Tensor, Tensor]):
915
+ keys, values, learning_rate = inputs
916
+ pred = functional_call(self.mlp, params, keys)
917
+ loss = F.mse_loss(pred, values, reduction = 'none')
918
+ loss = loss * learning_rate
919
+ return loss.mean()
920
+
921
+ self.grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0, 0)))
887
922
 
888
- return embed, next_kv_cache
923
+ self.retrieve_fn = vmap(retrieve_fn, in_dims = (0, 0))
924
+
925
+ # forgetting
926
+
927
+ self.to_forget_gate = nn.Sequential(
928
+ Reduce('b n d -> b d', 'mean'),
929
+ nn.Linear(dim, 1, bias = False),
930
+ Rearrange('b 1 -> b'),
931
+ nn.Sigmoid()
932
+ )
933
+
934
+ # loss weight / learning rate
935
+
936
+ self.to_loss_weight = nn.Linear(dim, 1, bias = False)
937
+
938
+ def get_init_mlp_params(
939
+ self,
940
+ batch_size
941
+ ):
942
+ return {name: repeat(params, '... -> b ...', b = batch_size) for name, params in self.init_mlp_params.items()}
943
+
944
+ def store(
945
+ self,
946
+ tokens, # (b n d)
947
+ memories: dict[str, Tensor] | None = None
948
+ ):
949
+
950
+ batch_size = tokens.shape[0]
951
+
952
+ if not exists(memories):
953
+ memories = self.get_init_mlp_params(batch_size)
954
+
955
+ tokens = self.norm(tokens)
956
+
957
+ keys, values = self.to_key_values(tokens).chunk(2, dim = -1)
958
+
959
+ loss_weight = self.to_loss_weight(tokens)
960
+
961
+ grad = self.grad_fn(memories, (keys, values, loss_weight))
962
+
963
+ # prepare forget
964
+
965
+ forget = self.to_forget_gate(tokens)
966
+
967
+ # update memories
968
+
969
+ next_memories = dict()
970
+
971
+ for param_name, past_memory in memories.items():
972
+ change = grad[param_name]
973
+
974
+ past_memory = einx.multiply('b, b ...', forget, past_memory)
975
+
976
+ next_memories[param_name] = past_memory - change
977
+
978
+ return next_memories
979
+
980
+ def forward(
981
+ self,
982
+ tokens, # (b n d)
983
+ memories: dict[str, Tensor] | None = None
984
+ ):
985
+ batch_size = tokens.shape[0]
986
+
987
+ if not exists(memories):
988
+ memories = self.get_init_mlp_params(batch_size)
989
+
990
+ tokens = self.norm(tokens)
991
+
992
+ queries = self.to_queries(tokens)
993
+
994
+ retrieved = self.retrieve_fn(memories, queries)
995
+
996
+ return retrieved
997
+
998
+ # state embedder
999
+
1000
+ class StateEmbedder(Module):
1001
+ @beartype
1002
+ def __init__(
1003
+ self,
1004
+ dim,
1005
+ dim_state: tuple[int, ...] | list[int] | int,
1006
+ num_internal_states: int | None = None,
1007
+ internal_states_selectors: list[list[int]] | None = None
1008
+ ):
1009
+ super().__init__()
1010
+ dim_hidden = dim * 2
1011
+
1012
+ self.image_to_token = nn.Sequential(
1013
+ Rearrange('b t c h w -> b c t h w'),
1014
+ nn.Conv3d(3, dim_hidden, (1, 7, 7), padding = (0, 3, 3)),
1015
+ nn.ReLU(),
1016
+ nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
1017
+ nn.ReLU(),
1018
+ nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
1019
+ Reduce('b c t h w -> b t c', 'mean'),
1020
+ nn.Linear(dim_hidden, dim)
1021
+ )
1022
+
1023
+ dim_states = (dim_state,) if not isinstance(dim_state, (tuple, list)) else dim_state
1024
+
1025
+ self.dim_states = dim_states
1026
+ self.state_to_token = ModuleList([MLP(dim_state, dim, bias = False) for dim_state in dim_states])
1027
+
1028
+ # internal state embeds for each robot
1029
+
1030
+ self.internal_state_embedder = None
1031
+
1032
+ if exists(num_internal_states) and exists(internal_states_selectors):
1033
+ self.internal_state_embedder = Embed(
1034
+ dim,
1035
+ num_continuous = num_internal_states,
1036
+ selectors = internal_states_selectors
1037
+ )
1038
+
1039
+ @property
1040
+ def device(self):
1041
+ return next(self.parameters()).device
1042
+
1043
+ def forward(
1044
+ self,
1045
+ state,
1046
+ state_type,
1047
+ state_id = 0,
1048
+ internal_state = None,
1049
+ internal_state_selector_id: int | None = None
1050
+ ):
1051
+
1052
+ if state_type == 'image':
1053
+ token_embeds = self.image_to_token(state)
1054
+ elif state_type == 'raw':
1055
+ state_to_token = self.state_to_token[state_id]
1056
+ token_embeds = state_to_token(state)
1057
+ else:
1058
+ raise ValueError('invalid state type')
1059
+
1060
+ if (
1061
+ exists(internal_state_selector_id) and
1062
+ exists(internal_state) and
1063
+ exists(self.internal_state_embedder)
1064
+ ):
1065
+ internal_state = internal_state.to(self.device)
1066
+
1067
+ internal_state_embed = self.internal_state_embedder(internal_state, selector_index = internal_state_selector_id)
1068
+
1069
+ token_embeds = token_embeds + internal_state_embed
1070
+
1071
+ return token_embeds
889
1072
 
890
1073
  # class
891
1074
 
1075
+ OneRewardShaper = Callable[..., float | Tensor]
1076
+
1077
+ MaybeOneRewardShaper = OneRewardShaper | None
1078
+
1079
+ @beartype
1080
+ def default_parse_env_reset_out(reset_out: tuple):
1081
+ assert len(reset_out) == 2
1082
+ return dict(zip(('state', 'info'), reset_out))
1083
+
1084
+ @beartype
1085
+ def default_parse_env_step_out(step_out: tuple):
1086
+ assert len(step_out) in {4, 5}
1087
+
1088
+ if len(step_out) == 5:
1089
+ data_dict = dict(zip(('state', 'reward', 'terminated', 'truncated', 'info'), step_out))
1090
+ elif len(step_out) == 4:
1091
+ data_dict = dict(zip(('state', 'reward', 'terminated', 'info'), step_out))
1092
+ data_dict['truncated'] = False
1093
+
1094
+ return data_dict
1095
+
892
1096
  class Locoformer(Module):
893
1097
  def __init__(
894
1098
  self,
895
- embedder: Module | ModuleList | list[Module],
896
- unembedder: Module,
1099
+ embedder: dict | Module,
1100
+ unembedder: dict | Readout,
897
1101
  transformer: dict | TransformerXL,
1102
+ *,
898
1103
  discount_factor = 0.999,
899
1104
  gae_lam = 0.95,
900
1105
  ppo_eps_clip = 0.2,
901
1106
  ppo_entropy_weight = 0.01,
902
1107
  ppo_value_clip = 0.4,
1108
+ ppo_soft_constrain_action_max = None,
1109
+ ppo_soft_constrain_action_loss_weight = 0.1,
903
1110
  dim_value_input = None, # needs to be set for value network to be available
904
1111
  value_network: Module = nn.Identity(),
1112
+ policy_network: Module = nn.Identity(),
1113
+ state_pred_network: Module | None = None,
1114
+ embed_past_action = False,
1115
+ state_pred_loss_weight = 0.05,
905
1116
  reward_range: tuple[float, float] | None = None,
906
- reward_shaping_fns: list[Callable[..., float | Tensor]] | None = None,
1117
+ reward_shaping_fns: (
1118
+ MaybeOneRewardShaper |
1119
+ list[MaybeOneRewardShaper] |
1120
+ list[list[MaybeOneRewardShaper]]
1121
+ ) = None,
907
1122
  num_reward_bins = 32,
908
1123
  hl_gauss_loss_kwargs = dict(),
909
1124
  value_loss_weight = 0.5,
910
1125
  calc_gae_kwargs: dict = dict(),
911
- recurrent_kv_cache = True,
912
- use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
1126
+ parse_env_reset_out: Callable | None = None,
1127
+ parse_env_step_out: Callable | None = None,
1128
+ recurrent_cache = True,
1129
+ use_spo = False, # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
1130
+ asymmetric_spo = False, # https://openreview.net/pdf?id=BA6n0nmagi
1131
+ max_mem_segments = 1
913
1132
  ):
914
1133
  super().__init__()
915
1134
 
916
1135
  if isinstance(transformer, dict):
917
- transformer = TransformerXL(**transformer)
1136
+ transformer = TransformerXL(max_mem_segments = max_mem_segments, **transformer)
918
1137
 
919
1138
  self.transformer = transformer
920
1139
 
921
1140
  # handle state embedder
922
1141
 
923
- if isinstance(embedder, list):
924
- embedder = ModuleList(embedder)
1142
+ if isinstance(embedder, dict):
1143
+ embedder = StateEmbedder(**embedder)
925
1144
 
926
1145
  self.embedder = embedder
927
1146
 
928
1147
  # unembed state to actions or ssl predictions
929
1148
 
1149
+ action_embedder = None
1150
+ if isinstance(unembedder, dict):
1151
+ action_embedder, unembedder = EmbedAndReadout(
1152
+ explicit_single_action_dim_given = True,
1153
+ **unembedder,
1154
+ )
1155
+
930
1156
  self.unembedder = unembedder
931
1157
 
1158
+ # embedding past actions
1159
+
1160
+ self.past_action_embedder = None
1161
+ self.embed_past_action = embed_past_action
1162
+
1163
+ if embed_past_action and exists(action_embedder):
1164
+ self.past_action_embedder = action_embedder
1165
+
1166
+ # attention window related
1167
+
932
1168
  self.fixed_window_size = transformer.fixed_window_size
933
1169
  self.window_size = transformer.window_size
1170
+ self.max_mem_segments = max_mem_segments
1171
+
1172
+ # policy network
1173
+
1174
+ self.policy_network = policy_network
934
1175
 
935
1176
  # determine value network, using HL Gauss Layer
936
1177
 
@@ -953,6 +1194,22 @@ class Locoformer(Module):
953
1194
  **hl_gauss_loss_kwargs
954
1195
  )
955
1196
 
1197
+ # state prediction related
1198
+
1199
+ self.can_pred_state = exists(state_pred_network)
1200
+ self.state_pred_network = state_pred_network
1201
+
1202
+ if exists(state_pred_network):
1203
+ dim_states = self.embedder.dim_states
1204
+ total_dim_states = sum(dim_states)
1205
+
1206
+ selectors = [t.tolist() for t in arange(total_dim_states).split(dim_states)]
1207
+
1208
+ self.state_pred_head = Readout(transformer.dim, num_continuous = total_dim_states, selectors = selectors)
1209
+
1210
+ self.has_state_pred_loss = state_pred_loss_weight > 0.
1211
+ self.state_pred_loss_weight = state_pred_loss_weight
1212
+
956
1213
  # ppo related
957
1214
 
958
1215
  self.discount_factor = discount_factor
@@ -960,6 +1217,9 @@ class Locoformer(Module):
960
1217
  self.ppo_eps_clip = ppo_eps_clip
961
1218
  self.ppo_entropy_weight = ppo_entropy_weight
962
1219
  self.ppo_value_clip = ppo_value_clip
1220
+ self.ppo_soft_constrain_action_max = ppo_soft_constrain_action_max
1221
+ self.ppo_soft_constrain_action_loss_weight = ppo_soft_constrain_action_loss_weight
1222
+
963
1223
  self.value_loss_weight = value_loss_weight
964
1224
 
965
1225
  self.calc_gae_kwargs = calc_gae_kwargs
@@ -968,14 +1228,26 @@ class Locoformer(Module):
968
1228
 
969
1229
  self.use_spo = use_spo
970
1230
 
1231
+ self.asymmetric_spo = asymmetric_spo
1232
+
971
1233
  # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
972
1234
 
973
- self.recurrent_kv_cache = recurrent_kv_cache
1235
+ self.recurrent_cache = recurrent_cache
1236
+
1237
+ # environment returns to dictionary
1238
+
1239
+ self.parse_env_reset_out = default(parse_env_reset_out, default_parse_env_reset_out)
1240
+ self.parse_env_step_out = default(parse_env_step_out, default_parse_env_step_out)
974
1241
 
975
1242
  # reward shaping function
976
1243
 
977
1244
  self.has_reward_shaping = exists(reward_shaping_fns)
1245
+
1246
+ if is_bearable(reward_shaping_fns, OneRewardShaper):
1247
+ reward_shaping_fns = [reward_shaping_fns]
1248
+
978
1249
  self.reward_shaping_fns = reward_shaping_fns
1250
+ self.reward_shaping_fns_multiple_envs = is_bearable(reward_shaping_fns, list[list[OneRewardShaper]])
979
1251
 
980
1252
  # loss related
981
1253
 
@@ -986,7 +1258,10 @@ class Locoformer(Module):
986
1258
  return next(self.parameters()).device
987
1259
 
988
1260
  def actor_parameters(self):
989
- return self.unembedder.parameters()
1261
+ return [
1262
+ *self.policy_network.parameters(),
1263
+ *self.unembedder.parameters()
1264
+ ]
990
1265
 
991
1266
  def critic_parameters(self):
992
1267
  if not exists(self.to_value_pred):
@@ -994,6 +1269,69 @@ class Locoformer(Module):
994
1269
 
995
1270
  return self.to_value_pred.parameters()
996
1271
 
1272
+ @beartype
1273
+ def learn(
1274
+ self,
1275
+ optims,
1276
+ accelerator,
1277
+ replay,
1278
+ state_embed_kwargs: dict,
1279
+ action_select_kwargs: dict,
1280
+ state_id_kwarg: dict = dict(),
1281
+ batch_size = 16,
1282
+ epochs = 2,
1283
+ use_vision = False,
1284
+ compute_state_pred_loss = False,
1285
+ state_pred_loss_weight = None,
1286
+ maybe_construct_trial_from_buffer: Callable[[ReplayBuffer], Tensor] | None = None
1287
+ ):
1288
+ state_field = 'state_image' if use_vision else 'state'
1289
+
1290
+ episode_mapping = None
1291
+
1292
+ if exists(maybe_construct_trial_from_buffer):
1293
+ episode_mapping = maybe_construct_trial_from_buffer(replay)
1294
+
1295
+ dataset = replay.dataset()
1296
+
1297
+ if exists(episode_mapping):
1298
+ dataset = RemappedReplayDataset(dataset, episode_mapping)
1299
+
1300
+ dl = replay.dataloader(
1301
+ batch_size = batch_size,
1302
+ dataset = dataset,
1303
+ shuffle = True
1304
+ )
1305
+
1306
+ self, dl, *optims = accelerator.prepare(self, dl, *optims)
1307
+
1308
+ for _ in range(epochs):
1309
+ for data in dl:
1310
+
1311
+ data = SimpleNamespace(**data)
1312
+
1313
+ actor_loss, critic_loss = self.ppo(
1314
+ state = getattr(data, state_field),
1315
+ internal_state = getattr(data, 'internal_state', None),
1316
+ action = data.action,
1317
+ action_log_prob = data.action_log_prob,
1318
+ reward = data.reward,
1319
+ value = data.value,
1320
+ done = data.done,
1321
+ condition = getattr(data, 'condition', None),
1322
+ cond_mask = getattr(data, 'cond_mask', None),
1323
+ episode_lens = data._lens,
1324
+ optims = optims,
1325
+ state_embed_kwargs = state_embed_kwargs,
1326
+ action_select_kwargs = action_select_kwargs,
1327
+ state_id_kwarg = state_id_kwarg,
1328
+ compute_state_pred_loss = compute_state_pred_loss,
1329
+ state_pred_loss_weight = state_pred_loss_weight,
1330
+ accelerator = accelerator
1331
+ )
1332
+
1333
+ accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
1334
+
997
1335
  def evolve(
998
1336
  self,
999
1337
  environment,
@@ -1005,47 +1343,66 @@ class Locoformer(Module):
1005
1343
  def ppo(
1006
1344
  self,
1007
1345
  state,
1346
+ internal_state,
1008
1347
  action,
1009
- old_action_log_prob,
1348
+ action_log_prob,
1010
1349
  reward,
1011
- old_value,
1012
- mask,
1350
+ value,
1351
+ done,
1013
1352
  episode_lens,
1014
1353
  condition: Tensor | None = None,
1015
- state_type: int | None = None,
1016
- actor_optim: Optimizer | None = None,
1017
- critic_optim: Optimizer | None = None
1354
+ cond_mask: Tensor | None = None,
1355
+ optims: list[Optimizer] | None = None,
1356
+ state_embed_kwargs: dict = dict(),
1357
+ action_select_kwargs: dict = dict(),
1358
+ state_id_kwarg: dict = dict(),
1359
+ compute_state_pred_loss = True,
1360
+ state_pred_loss_weight = None,
1361
+ accelerator = None,
1362
+ max_grad_norm = 0.5
1018
1363
  ):
1019
- window_size = self.window_size
1020
- total_learnable_tokens = mask.sum().item()
1364
+ state_pred_loss_weight = default(state_pred_loss_weight, self.state_pred_loss_weight)
1021
1365
 
1366
+ window_size = self.window_size
1367
+ mask = ~done
1022
1368
  seq_len = state.shape[1]
1023
- gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
1369
+ padding_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
1370
+ gae_mask = padding_mask & mask
1371
+
1372
+ total_learnable_tokens = gae_mask.sum().item()
1373
+
1374
+ advantage, returns = calc_gae(reward, value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
1024
1375
 
1025
- advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
1376
+ advantage = normalize(advantage, mask = gae_mask)
1026
1377
 
1027
- advantage = normalize(advantage)
1378
+ advantage = rearrange(advantage, '... -> ... 1')
1028
1379
 
1029
- data_tensors = (
1030
- state,
1031
- action,
1032
- old_action_log_prob,
1033
- reward,
1034
- old_value,
1035
- mask,
1036
- advantage,
1037
- returns
1380
+ past_action = pad_at_dim(action, (1, -1), dim = -2)
1381
+
1382
+ data_dict = dict(
1383
+ state = state,
1384
+ internal_state = internal_state,
1385
+ action = action,
1386
+ past_action = past_action,
1387
+ old_action_log_prob = action_log_prob,
1388
+ reward = reward,
1389
+ mask = mask,
1390
+ advantage = advantage,
1391
+ returns = returns,
1392
+ windowed_gae_mask = gae_mask,
1393
+ condition = condition,
1394
+ cond_mask = cond_mask
1038
1395
  )
1039
1396
 
1040
- has_condition = exists(condition)
1397
+ num_windows = math.ceil(seq_len / window_size)
1041
1398
 
1042
- if exists(condition):
1043
- data_tensors = (*data_tensors, condition)
1399
+ windowed_data = dict()
1044
1400
 
1045
- windowed_tensors = [
1046
- t.split(window_size, dim = 1) for t in
1047
- data_tensors
1048
- ]
1401
+ for name, tensor in data_dict.items():
1402
+ if exists(tensor):
1403
+ windowed_data[name] = tensor.split(window_size, dim = 1)
1404
+ else:
1405
+ windowed_data[name] = (None,) * num_windows
1049
1406
 
1050
1407
  mean_actor_loss = self.zero.clone()
1051
1408
  mean_critic_loss = self.zero.clone()
@@ -1054,56 +1411,90 @@ class Locoformer(Module):
1054
1411
 
1055
1412
  cache = None
1056
1413
 
1057
- for (
1058
- state,
1059
- action,
1060
- old_action_log_prob,
1061
- reward,
1062
- old_value,
1063
- mask,
1064
- advantage,
1065
- returns,
1066
- *rest
1067
- ) in zip(*windowed_tensors):
1414
+ for window_tensors in zip(*windowed_data.values()):
1068
1415
 
1069
- if has_condition:
1070
- condition, = rest
1416
+ data = SimpleNamespace(**dict(zip(windowed_data.keys(), window_tensors)))
1071
1417
 
1072
- (action_logits, value_logits), cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
1073
- entropy = calc_entropy(action_logits)
1418
+ ((action_logits, maybe_state_pred), value_logits), cache = self.forward(data.state, past_action = data.past_action if self.embed_past_action else None, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': data.internal_state}, action_select_kwargs = action_select_kwargs, state_id_kwarg = state_id_kwarg, condition = data.condition, cond_mask = data.cond_mask, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True, return_state_pred = True)
1074
1419
 
1075
- action = rearrange(action, 'b t -> b t 1')
1076
- log_prob = action_logits.gather(-1, action)
1077
- log_prob = rearrange(log_prob, 'b t 1 -> b t')
1420
+ log_prob = self.unembedder.log_prob(action_logits, data.action, **action_select_kwargs)
1078
1421
 
1079
1422
  # update actor, classic clipped surrogate loss
1080
1423
 
1081
1424
  eps_clip = self.ppo_eps_clip
1082
- ratio = (log_prob - old_action_log_prob).exp()
1425
+ ratio = (log_prob - data.old_action_log_prob).exp()
1426
+
1427
+ calc_spo = lambda: -(ratio * data.advantage - (data.advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
1083
1428
 
1084
- if self.use_spo:
1085
- actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
1429
+ calc_ppo = lambda: -torch.min(ratio * data.advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * data.advantage)
1430
+
1431
+ if self.asymmetric_spo:
1432
+ actor_loss = torch.where(data.advantage >= 0, calc_ppo(), calc_spo())
1433
+ elif self.use_spo:
1434
+ actor_loss = calc_spo()
1086
1435
  else:
1087
- actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
1436
+ actor_loss = calc_ppo()
1088
1437
 
1089
- actor_loss = actor_loss - self.ppo_entropy_weight * entropy
1438
+ # maybe entropy
1090
1439
 
1091
- windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
1092
- windowed_actor_loss.backward(retain_graph = True)
1440
+ if self.ppo_entropy_weight > 0.:
1441
+ entropy = self.unembedder.entropy(action_logits, **action_select_kwargs)
1093
1442
 
1094
- # update critic
1443
+ if exists(entropy):
1444
+ actor_loss = actor_loss - self.ppo_entropy_weight * entropy
1445
+
1446
+ windowed_actor_loss = actor_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
1447
+
1448
+ # maybe add state prediction
1095
1449
 
1096
- value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
1450
+ if (
1451
+ exists(maybe_state_pred) and
1452
+ self.has_state_pred_loss and
1453
+ compute_state_pred_loss and
1454
+ data.windowed_gae_mask[:, :-1].any()
1455
+ ):
1456
+ state_pred = maybe_state_pred[:, :-1]
1457
+ state_labels = data.state[:, 1:]
1458
+ loss_mask = data.windowed_gae_mask[:, :-1]
1459
+
1460
+ state_id = state_id_kwarg.get('state_id', 0)
1461
+
1462
+ state_pred_loss = self.state_pred_head.calculate_loss(state_pred, state_labels, selector_index = state_id, return_unreduced_loss = True)
1463
+
1464
+ state_pred_loss = state_pred_loss.mean(dim = -1) # average over state features
1465
+
1466
+ windowed_state_pred_loss = state_pred_loss[loss_mask].sum() / total_learnable_tokens
1467
+
1468
+ windowed_actor_loss = (
1469
+ windowed_actor_loss +
1470
+ windowed_state_pred_loss * state_pred_loss_weight
1471
+ )
1472
+
1473
+ # maybe soft constrain continuous actions
1474
+
1475
+ if (
1476
+ self.ppo_soft_constrain_action_max and
1477
+ self.unembedder.has_continuous
1478
+ ):
1479
+ loss_mask = data.windowed_gae_mask
1480
+
1481
+ soft_constrain_loss = (action_logits[..., 0].abs() - self.ppo_soft_constrain_action_max).relu().pow(2)
1482
+ windowed_soft_constrain_loss = soft_constrain_loss[loss_mask].sum() / total_learnable_tokens
1097
1483
 
1098
- value_clip = self.ppo_value_clip
1099
- value = self.hl_gauss_loss(value_logits)
1484
+ windowed_actor_loss = (
1485
+ windowed_actor_loss +
1486
+ windowed_soft_constrain_loss * self.ppo_soft_constrain_action_loss_weight
1487
+ )
1100
1488
 
1101
- clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
1102
- clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
1489
+ # windowed loss
1103
1490
 
1104
- critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
1491
+ windowed_actor_loss.backward(retain_graph = True)
1492
+
1493
+ # update critic
1105
1494
 
1106
- windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
1495
+ value_loss = self.hl_gauss_loss(value_logits, data.returns, reduction = 'none') * self.value_loss_weight
1496
+
1497
+ windowed_critic_loss = value_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
1107
1498
  windowed_critic_loss.backward(retain_graph = True)
1108
1499
 
1109
1500
  # accumulate
@@ -1113,13 +1504,16 @@ class Locoformer(Module):
1113
1504
 
1114
1505
  # optimizer update
1115
1506
 
1116
- if exists(actor_optim):
1117
- actor_optim.step()
1118
- actor_optim.zero_grad()
1507
+ if exists(optims):
1508
+
1509
+ if exists(accelerator):
1510
+ accelerator.clip_grad_norm_(self.parameters(), max_grad_norm)
1511
+ else:
1512
+ nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
1119
1513
 
1120
- if exists(critic_optim):
1121
- critic_optim.step()
1122
- critic_optim.zero_grad()
1514
+ for optim in optims:
1515
+ optim.step()
1516
+ optim.zero_grad()
1123
1517
 
1124
1518
  # return losses for logging
1125
1519
 
@@ -1128,14 +1522,18 @@ class Locoformer(Module):
1128
1522
  def state_and_command_to_rewards(
1129
1523
  self,
1130
1524
  state,
1131
- commands = None
1525
+ commands = None,
1526
+ env_index: int | None = None
1132
1527
  ) -> Tensor:
1133
1528
 
1134
1529
  assert self.has_reward_shaping
1530
+ assert xnor(exists(env_index), self.reward_shaping_fns_multiple_envs), f'`env_index` must be passed in if multiple reward shaping functions are defined, and vice versa (not passed in if only single list of reward shaping functions)'
1135
1531
 
1136
1532
  rewards = []
1137
1533
 
1138
- for fn in self.reward_shaping_fns:
1534
+ reward_shaping_fns = self.reward_shaping_fns[env_index] if exists(env_index) else self.reward_shaping_fns
1535
+
1536
+ for fn in reward_shaping_fns:
1139
1537
  param_names = get_param_names(fn)
1140
1538
  param_names = set(param_names) & {'state', 'command'}
1141
1539
 
@@ -1152,9 +1550,23 @@ class Locoformer(Module):
1152
1550
 
1153
1551
  rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
1154
1552
 
1155
- return stack(rewards)
1553
+ assert all([r.numel() == 1 for r in rewards])
1554
+
1555
+ if len(rewards) == 0:
1556
+ return None
1156
1557
 
1157
- def wrap_env_functions(self, env):
1558
+ packed_rewards, _ = pack(rewards, '*')
1559
+ return packed_rewards
1560
+
1561
+ @beartype
1562
+ def wrap_env_functions(
1563
+ self,
1564
+ env,
1565
+ env_output_transforms: dict[str, Callable] = dict(),
1566
+ state_transform: Callable = identity,
1567
+ reward_norm = 1.,
1568
+ command_generator: Callable = always(None)
1569
+ ):
1158
1570
 
1159
1571
  def transform_output(el):
1160
1572
  if isinstance(el, ndarray):
@@ -1167,23 +1579,57 @@ class Locoformer(Module):
1167
1579
  def wrapped_reset(*args, **kwargs):
1168
1580
  env_reset_out = env.reset(*args, **kwargs)
1169
1581
 
1170
- return tree_map(transform_output, env_reset_out)
1582
+ env_reset_out_torch = tree_map(transform_output, env_reset_out)
1583
+
1584
+ env_reset_out_dict = self.parse_env_reset_out(env_reset_out_torch)
1585
+
1586
+ env_reset_out_dict['state'] = state_transform(env_reset_out_dict['state'])
1587
+
1588
+ derived_states = dict()
1171
1589
 
1172
- def wrapped_step(action, *args, **kwargs):
1590
+ for derived_name, transform in env_output_transforms.items():
1591
+ derived_states[derived_name] = transform(env_reset_out_dict, env)
1592
+
1593
+ env_reset_out_dict['derived_state'] = derived_states
1594
+
1595
+ return env_reset_out_dict
1596
+
1597
+ def wrapped_step(action, *args, command = None, env_index = None, **kwargs):
1173
1598
 
1174
1599
  if is_tensor(action):
1175
- action = action.item()
1600
+ if action.numel() == 1:
1601
+ action = action.item()
1602
+ else:
1603
+ action = action.tolist()
1176
1604
 
1177
1605
  env_step_out = env.step(action, *args, **kwargs)
1178
1606
 
1179
1607
  env_step_out_torch = tree_map(transform_output, env_step_out)
1180
1608
 
1181
- if not self.has_reward_shaping:
1182
- return env_step_out_torch
1609
+ env_step_out_dict = self.parse_env_step_out(env_step_out_torch)
1610
+
1611
+ env_step_out_dict['state'] = state_transform(env_step_out_dict['state'])
1612
+
1613
+ env_step_out_dict['reward'] = env_step_out_dict['reward'] / reward_norm
1183
1614
 
1184
- shaped_rewards = self.state_and_command_to_rewards(env_step_out_torch)
1615
+ if self.has_reward_shaping:
1616
+ shaped_rewards = self.state_and_command_to_rewards(env_step_out_dict['state'], command, env_index = env_index)
1185
1617
 
1186
- return env_step_out_torch, shaped_rewards
1618
+ if exists(shaped_rewards):
1619
+ env_step_out_dict['shaped_rewards'] = shaped_rewards
1620
+
1621
+ # add shaped rewards to main reward
1622
+
1623
+ env_step_out_dict['reward'] = env_step_out_dict['reward'] + shaped_rewards.sum()
1624
+
1625
+ derived_states = dict()
1626
+
1627
+ for derived_name, transform in env_output_transforms.items():
1628
+ derived_states[derived_name] = transform(env_step_out_dict, env)
1629
+
1630
+ env_step_out_dict['derived_state'] = derived_states
1631
+
1632
+ return env_step_out_dict
1187
1633
 
1188
1634
  return wrapped_reset, wrapped_step
1189
1635
 
@@ -1196,14 +1642,13 @@ class Locoformer(Module):
1196
1642
  state_time_dim = 1,
1197
1643
  **kwargs
1198
1644
  ):
1199
- window_size = self.window_size
1200
1645
 
1201
1646
  cache = None
1202
1647
 
1203
1648
  def stateful_forward(
1204
1649
  state: Tensor,
1205
1650
  condition: Tensor | None = None,
1206
- state_type: int | None = None,
1651
+ cond_mask: Tensor | None = None,
1207
1652
  **override_kwargs
1208
1653
  ):
1209
1654
  nonlocal cache
@@ -1213,6 +1658,9 @@ class Locoformer(Module):
1213
1658
  if exists(condition):
1214
1659
  condition = condition.to(self.device)
1215
1660
 
1661
+ if exists(cond_mask):
1662
+ cond_mask = cond_mask.to(self.device)
1663
+
1216
1664
  # handle no batch or time, for easier time rolling out against envs
1217
1665
 
1218
1666
  if not has_batch_dim:
@@ -1221,15 +1669,27 @@ class Locoformer(Module):
1221
1669
  if exists(condition):
1222
1670
  condition = rearrange(condition, '... -> 1 ...')
1223
1671
 
1672
+ if exists(cond_mask):
1673
+ cond_mask = rearrange(cond_mask, '... -> 1 ...')
1674
+
1224
1675
  if not has_time_dim:
1225
1676
  state = state.unsqueeze(state_time_dim)
1226
1677
 
1227
1678
  if exists(condition):
1228
1679
  condition = rearrange(condition, '... d -> ... 1 d')
1229
1680
 
1681
+ if exists(cond_mask):
1682
+ cond_mask = cond_mask.unsqueeze(state_time_dim)
1683
+
1230
1684
  # forwards
1231
1685
 
1232
- out, cache = self.forward(state, condition = condition, state_type = state_type, cache = cache, **{**kwargs, **override_kwargs})
1686
+ out, cache = self.forward(
1687
+ state,
1688
+ condition = condition,
1689
+ cond_mask = cond_mask,
1690
+ cache = cache,
1691
+ **{**kwargs, **override_kwargs}
1692
+ )
1233
1693
 
1234
1694
  # maybe remove batch or time
1235
1695
 
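The stateful_forward closure returned above carries the transformer cache between calls and promotes bare per-step inputs (state, condition, cond_mask) to batched, single-timestep shapes before calling forward, then strips those dims again. A minimal standalone sketch of the same closure pattern, with a toy forward function standing in for the model (all names here are illustrative):

import torch
from torch import Tensor

def get_stateful_forward(forward_fn, has_batch_dim = False, has_time_dim = False):
    cache = None  # carried across calls by the closure

    def stateful_forward(state: Tensor, condition: Tensor | None = None, cond_mask: Tensor | None = None):
        nonlocal cache

        # add batch / time dims so a bare per-step observation fits a (batch, time, ...) model
        if not has_batch_dim:
            state = state.unsqueeze(0)
            condition = condition.unsqueeze(0) if condition is not None else None
            cond_mask = cond_mask.unsqueeze(0) if cond_mask is not None else None

        if not has_time_dim:
            state = state.unsqueeze(1)
            cond_mask = cond_mask.unsqueeze(1) if cond_mask is not None else None

        out, cache = forward_fn(state, condition = condition, cond_mask = cond_mask, cache = cache)

        # strip the dims that were added, so the caller keeps its own shape convention
        if not has_time_dim:
            out = out.squeeze(1)
        if not has_batch_dim:
            out = out.squeeze(0)

        return out

    return stateful_forward

# toy forward_fn - the "cache" here is just a running count of tokens seen so far
def toy_forward(state, condition = None, cond_mask = None, cache = None):
    total_tokens = 0 if cache is None else cache
    return state * 2, total_tokens + state.shape[1]

fwd = get_stateful_forward(toy_forward)

for _ in range(3):
    out = fwd(torch.randn(4))  # the carried cache accumulates: 1, 2, 3 tokens seen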
@@ -1260,60 +1720,275 @@ class Locoformer(Module):
1260
1720
 
1261
1721
  return stateful_forward, initial_logits
1262
1722
 
1723
+ @beartype
1724
+ def gather_experience_from_env_(
1725
+ self,
1726
+ wrapped_env_functions: tuple[Callable, Callable],
1727
+ replay: ReplayBuffer,
1728
+ embed_past_action = False,
1729
+ max_timesteps = None,
1730
+ use_vision = False,
1731
+ action_select_kwargs: dict = dict(),
1732
+ state_embed_kwargs: dict = dict(),
1733
+ state_id_kwarg: dict = dict(),
1734
+ env_index: int | None = None,
1735
+ state_entropy_bonus_weight = 0.,
1736
+ action_rescale_range: tuple[float, float] | None = None,
1737
+ command_fn: Callable = always(None)
1738
+ ):
1739
+
1740
+ env_reset, env_step = wrapped_env_functions
1741
+
1742
+ reset_out_dict = env_reset()
1743
+ derived, state = pick(reset_out_dict, ('derived_state', 'state'))
1744
+
1745
+ state_image = derived.get('state_image', None)
1746
+ internal_state = derived.get('internal_state', None)
1747
+
1748
+ timestep = 0
1749
+
1750
+ max_timesteps = default(max_timesteps, replay.max_timesteps)
1751
+
1752
+ stateful_forward = self.get_stateful_forward(
1753
+ has_batch_dim = False,
1754
+ has_time_dim = False,
1755
+ inference_mode = True
1756
+ )
1757
+
1758
+ cum_rewards = 0.
1759
+
1760
+ with replay.one_episode() as final_meta_data_store_dict:
1761
+
1762
+ past_action = None
1763
+
1764
+ while True:
1765
+ state_for_model = state_image if use_vision else state
1766
+
1767
+ maybe_command = command_fn(state_for_model)
1768
+
1769
+ # predict next action
1770
+
1771
+ (action_logits, state_pred), value = stateful_forward(
1772
+ state_for_model,
1773
+ condition = maybe_command,
1774
+ cond_mask = tensor(exists(maybe_command)),
1775
+ state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state},
1776
+ action_select_kwargs = action_select_kwargs,
1777
+ state_id_kwarg = state_id_kwarg,
1778
+ past_action = past_action if embed_past_action else None,
1779
+ return_values = True,
1780
+ return_state_pred = True
1781
+ )
1782
+
1783
+ action = self.unembedder.sample(action_logits, **action_select_kwargs)
1784
+
1785
+ # maybe clip
1786
+
1787
+ if exists(action_rescale_range):
1788
+ min_val, max_val = action_rescale_range
1789
+ action = (action + 1.) * 0.5 * (max_val - min_val) + min_val
1790
+
1791
+ # pass to environment
1792
+
1793
+ step_dict = env_step(action, command = maybe_command, env_index = env_index)
1794
+
1795
+ derived, next_state, reward, terminated, truncated = pick(step_dict, ('derived_state', 'state', 'reward', 'terminated', 'truncated'))
1796
+
1797
+ next_state_image = derived.get('state_image', None)
1798
+ next_internal_state = derived.get('internal_state', None)
1799
+
1800
+ # maybe state entropy bonus
1801
+
1802
+ if state_entropy_bonus_weight > 0. and exists(state_pred):
1803
+ state_id = state_id_kwarg.get('state_id', 0)
1804
+ entropy = self.state_pred_head.entropy(state_pred, selector_index = state_id)
1805
+
1806
+ state_entropy_bonus = (entropy * state_entropy_bonus_weight).sum()
1807
+
1808
+ reward = reward + state_entropy_bonus.item() # the entropy is directly related to log variance
1809
+
1810
+ # cum rewards
1811
+
1812
+ cum_rewards += reward
1813
+
1814
+ # determine whether this step ends the episode (termination, truncation, or timestep cap)
1815
+ # we will store the step with done=False, as only the bootstrap/boundary node is done=True
1816
+
1817
+ exceeds_max_timesteps = max_timesteps >= 0 and timestep == (max_timesteps - 1)
1818
+ should_stop = truncated or terminated or tensor(exceeds_max_timesteps)
1819
+
1820
+ # get log prob of action
1821
+
1822
+ action_log_prob = self.unembedder.log_prob(action_logits, action, **action_select_kwargs)
1823
+
1824
+ memory = replay.store(
1825
+ state = state,
1826
+ state_image = state_image,
1827
+ action = action,
1828
+ action_log_prob = action_log_prob,
1829
+ internal_state = internal_state,
1830
+ reward = reward,
1831
+ value = value,
1832
+ done = tensor(False),
1833
+ condition = maybe_command,
1834
+ cond_mask = tensor(exists(maybe_command))
1835
+ )
1836
+
1837
+ timestep += 1
1838
+
1839
+ # break if done or exceed max timestep
1840
+ if should_stop:
1841
+
1842
+ # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
1843
+ # only if terminated signal not detected
1844
+ if not terminated:
1845
+ next_state_for_model = next_state_image if use_vision else next_state
1846
+
1847
+ _, next_value = stateful_forward(next_state_for_model, condition = maybe_command, cond_mask = tensor(exists(maybe_command)), return_values = True, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state}, state_id_kwarg = state_id_kwarg, action_select_kwargs = action_select_kwargs)
1848
+
1849
+ terminal_node = dict(
1850
+ state = next_state,
1851
+ state_image = next_state_image,
1852
+ internal_state = next_internal_state,
1853
+ value = next_value,
1854
+ reward = next_value,
1855
+ done = tensor(True),
1856
+ condition = maybe_command,
1857
+ cond_mask = tensor(exists(maybe_command))
1858
+ )
1859
+
1860
+ else:
1861
+ # terminal node - store a step with 0 reward and value, and done=True, to stop GAE scan
1862
+ terminal_node = dict(
1863
+ state = next_state,
1864
+ state_image = next_state_image,
1865
+ internal_state = next_internal_state,
1866
+ value = torch.zeros_like(value),
1867
+ reward = torch.zeros_like(reward),
1868
+ done = tensor(True),
1868
+ condition = maybe_command,
1869
+ cond_mask = tensor(exists(maybe_command))
1871
+ )
1872
+
1873
+ terminal_node = {k: v for k, v in terminal_node.items() if k in memory._fields}
1874
+
1875
+ terminal_memory = memory._replace(**terminal_node)
1876
+
1877
+ replay.store(**terminal_memory._asdict())
1878
+
1879
+ # store the final cumulative reward into meta data
1880
+
1881
+ final_meta_data_store_dict.update(cum_rewards = cum_rewards)
1882
+
1883
+ break
1884
+
1885
+ state = next_state
1886
+ state_image = next_state_image
1887
+ internal_state = next_internal_state
1888
+
1889
+ past_action = action
1890
+
1891
+ return cum_rewards
1892
+
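On truncation (or hitting the timestep cap) the loop above stores one extra boundary step whose value slot holds the bootstrapped value of the next state, while a true termination stores zeros; in both cases done is True, so a later advantage scan stops at the boundary. A generic GAE sketch consistent with that convention (this is not the package's training code, just an illustration of how such boundary nodes are consumed):

import torch

def gae_from_boundary_nodes(rewards, values, dones, gamma = 0.99, lam = 0.95):
    # rewards / values / dones are 1d tensors whose final entry is the boundary node:
    # its value holds the bootstrapped V(s_T+1) on truncation, zero on termination, and done is True
    advantages = torch.zeros_like(rewards)
    running = 0.

    for t in reversed(range(len(rewards) - 1)):  # the boundary node itself is not a learnable step
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * (1. - dones[t + 1].float()) * running
        advantages[t] = running

    return advantages[:-1]  # advantages only for the real, stored steps

rewards = torch.tensor([1., 1., 1., 0.])       # final entry is the boundary node (illustrative values)
values  = torch.tensor([0.5, 0.6, 0.7, 0.9])   # 0.9 = bootstrapped value after a truncated episode
dones   = torch.tensor([0., 0., 0., 1.])

advantages = gae_from_boundary_nodes(rewards, values, dones)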
1263
1893
  def forward(
1264
1894
  self,
1265
1895
  state: Tensor,
1266
- cache: Cache | None = None,
1896
+ cache: TransformerMemory | None = None,
1267
1897
  condition: Tensor | None = None,
1268
- state_type: int | None = None,
1898
+ cond_mask: Tensor | None = None,
1899
+ past_action: Tensor | None = None,
1900
+ state_embed_kwargs: dict = dict(),
1901
+ action_select_kwargs: dict = dict(),
1902
+ state_id_kwarg: dict = dict(),
1269
1903
  detach_cache = False,
1270
1904
  return_values = False,
1905
+ return_state_pred = False,
1271
1906
  return_raw_value_logits = False
1272
1907
  ):
1273
1908
 
1274
1909
  state = state.to(self.device)
1275
1910
 
1911
+ # move condition
1912
+
1913
+ if exists(condition):
1914
+ condition = condition.to(self.device)
1915
+
1276
1916
  # determine which function to invoke for state to token for transformer
1277
1917
 
1278
1918
  state_to_token = self.embedder
1279
1919
 
1280
- if exists(state_type):
1281
- state_to_token = self.embedder[state_type]
1282
-
1283
1920
  # embed
1284
1921
 
1285
- tokens = state_to_token(state)
1922
+ tokens = state_to_token(state, **state_embed_kwargs, **state_id_kwarg)
1286
1923
 
1287
- # time
1924
+ # maybe add past action
1288
1925
 
1289
- time = tokens.shape[-2]
1926
+ # determine if first window and start of sequence
1290
1927
 
1291
- # destruct the cache for the current timestep and the cache
1928
+ total_tokens = cache.total_tokens if exists(cache) else 0
1292
1929
 
1293
- prev_kv_cache = None
1294
- timestep_start = 0
1930
+ is_start_of_sequence = total_tokens == 0
1295
1931
 
1296
- if exists(cache):
1297
- timestep_start, prev_kv_cache = cache
1932
+ # maybe add past action
1933
+
1934
+ if exists(past_action):
1935
+ assert self.embed_past_action
1936
+ past_action_embed = self.past_action_embedder(past_action, **action_select_kwargs)
1937
+
1938
+ if is_start_of_sequence:
1939
+ past_action_embed = pad_at_dim(past_action_embed[..., 1:, :], (1, 0), dim = -2)
1940
+
1941
+ tokens = tokens + past_action_embed
1942
+
1943
+ # time
1944
+
1945
+ time = tokens.shape[-2]
1298
1946
 
1299
1947
  # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
1300
1948
 
1301
- assert ((timestep_start % self.window_size) + time) <= self.window_size
1949
+ assert ((total_tokens % self.window_size) + time) <= self.window_size
1302
1950
 
1303
1951
  # attention
1304
1952
 
1305
- embed, kv_cache = self.transformer(tokens, condition = condition, cache = prev_kv_cache, return_kv_cache = True)
1953
+ if not exists(cache):
1954
+ memory_segments = deque(maxlen = self.max_mem_segments)
1955
+ else:
1956
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
1957
+
1958
+ attn_cache = cache
1959
+
1960
+ embed, cache = self.transformer(
1961
+ tokens,
1962
+ condition = condition,
1963
+ cond_mask = cond_mask,
1964
+ cache = attn_cache,
1965
+ return_kv_cache = True
1966
+ )
1306
1967
 
1307
1968
  # unembed to actions - in language models this would be the next state
1308
1969
 
1309
- action_logits = self.unembedder(embed)
1970
+ policy_embed = self.policy_network(embed)
1971
+
1972
+ action_logits = self.unembedder(policy_embed, **action_select_kwargs)
1310
1973
 
1311
1974
  out = action_logits
1312
1975
 
1976
+ # maybe return state prediction
1977
+
1978
+ if return_state_pred:
1979
+ state_pred = None
1980
+
1981
+ if self.can_pred_state:
1982
+ state_id = state_id_kwarg.get('state_id', 0)
1983
+ state_pred_embed = self.state_pred_network(embed)
1984
+ state_pred = self.state_pred_head(state_pred_embed, selector_index = state_id)
1985
+
1986
+ out = (out, state_pred)
1987
+
1313
1988
  # maybe detach cache
1314
1989
 
1315
1990
  if detach_cache:
1316
- kv_cache = kv_cache.detach()
1991
+ cache = tree_map_tensor(cache, lambda t: t.detach())
1317
1992
 
1318
1993
  # handle returning of values
1319
1994
 
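Because the cache is now a namedtuple mixing tensors with counters and containers, detach_cache above detaches it with a pytree map over tensor leaves rather than a single .detach(). A small sketch of that pattern on a simplified, hypothetical cache (a list stands in for the deque of memory segments, since an unregistered deque would be treated as a single leaf by torch's pytree utilities):

import torch
from collections import namedtuple
from torch import is_tensor
from torch.utils._pytree import tree_map

# hypothetical, simplified cache - only to illustrate detaching every tensor leaf
ToyCache = namedtuple('ToyCache', ('total_tokens', 'kv_cache', 'memory_segments'))

def detach_all_tensors(obj):
    # tree_map visits every leaf of the nested structure; non-tensor leaves (ints, None) pass through
    return tree_map(lambda t: t.detach() if is_tensor(t) else t, obj)

cache = ToyCache(
    total_tokens = 16,
    kv_cache = torch.randn(2, 16, 8, requires_grad = True),
    memory_segments = [torch.randn(2, 16, 8, requires_grad = True)]
)

detached = detach_all_tensors(cache)
assert not detached.kv_cache.requires_grad
assert not detached.memory_segments[0].requires_grad
assert detached.total_tokens == 16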
@@ -1327,20 +2002,22 @@ class Locoformer(Module):
1327
2002
 
1328
2003
  out = (out, values)
1329
2004
 
1330
- # output and cache
2005
+ # handle curtailing kv cache at the right intervals
1331
2006
 
1332
- next_timestep = time + timestep_start
2007
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
1333
2008
 
1334
- # handle curtailing kv cache at the right intervals
2009
+ if self.fixed_window_size or divisible_by(total_tokens, self.window_size * 2):
2010
+ kv_cache = kv_cache[..., -self.window_size:, :]
1335
2011
 
1336
- window_size = self.window_size
2012
+ if self.recurrent_cache and divisible_by(total_tokens, self.window_size):
2013
+ kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
1337
2014
 
1338
- if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
1339
- kv_cache = kv_cache[..., -window_size:, :]
2015
+ if exists(gru_cache):
2016
+ gru_cache = torch.roll(gru_cache, shifts = -1, dims = 0)
1340
2017
 
1341
- # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
2018
+ if divisible_by(total_tokens, self.window_size):
2019
+ memory_segments.append(kv_cache.detach())
1342
2020
 
1343
- if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
1344
- kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
2021
+ cache = TransformerMemory(total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments)
1345
2022
 
1346
- return out, (next_timestep, kv_cache)
2023
+ return out, cache
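The tail of forward above trims the kv cache back to one window every two windows of tokens, optionally rolls it across the layer dimension for the recurrent variant, and snapshots a detached copy into memory_segments at each window boundary. A toy sketch of that bookkeeping on a dummy (layers, time, dim) cache (the real cache also carries gru and memory-MLP state, omitted here):

import torch
from collections import deque

def divisible_by(num, den):
    return (num % den) == 0

def step_cache(kv_cache, new_tokens, total_tokens, window_size, segments,
               fixed_window_size = False, recurrent_cache = False):
    # kv_cache / new_tokens: (layers, time, dim) toy stand-ins for per-layer key/value memories
    kv_cache = torch.cat((kv_cache, new_tokens), dim = 1)
    total_tokens += new_tokens.shape[1]

    # trim back to one window, either always (fixed window) or once every two windows
    if fixed_window_size or divisible_by(total_tokens, window_size * 2):
        kv_cache = kv_cache[:, -window_size:, :]

    # recurrent variant: at each window boundary, shift each layer's cache to the layer below,
    # extending the receptive field over past segments
    if recurrent_cache and divisible_by(total_tokens, window_size):
        kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)

    # snapshot a detached copy of the cache once per completed window
    if divisible_by(total_tokens, window_size):
        segments.append(kv_cache.detach())

    return kv_cache, total_tokens

layers, dim, window = 4, 8, 16
kv_cache = torch.zeros(layers, 0, dim)
segments = deque(maxlen = 3)   # arbitrary cap, standing in for max_mem_segments
total_tokens = 0

for _ in range(4 * window):
    new_token = torch.randn(layers, 1, dim)
    kv_cache, total_tokens = step_cache(
        kv_cache, new_token, total_tokens, window, segments, recurrent_cache = True
    )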