locoformer-0.0.29-py3-none-any.whl → locoformer-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/locoformer.py CHANGED
@@ -1,10 +1,16 @@
1
1
  from __future__ import annotations
2
+ import math
2
3
  from typing import Callable
3
- from functools import partial
4
+ from types import SimpleNamespace
5
+ from functools import partial, wraps
4
6
 
5
7
  from pathlib import Path
6
8
  from contextlib import contextmanager
7
- from collections import namedtuple
9
+ from collections import namedtuple, deque
10
+
11
+ from glom import glom
12
+
13
+ from inspect import signature
8
14
 
9
15
  import numpy as np
10
16
  from numpy import ndarray
@@ -14,16 +20,17 @@ from beartype import beartype
14
20
  from beartype.door import is_bearable
15
21
 
16
22
  import torch
17
- from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
23
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy, nested
18
24
  import torch.nn.functional as F
19
25
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
20
- from torch.utils._pytree import tree_map
26
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
21
27
  from torch.utils.data import Dataset, DataLoader
28
+ from torch.distributions import Normal
22
29
  from torch.optim import Optimizer
23
30
 
24
31
  import einx
25
- from einops import rearrange, einsum
26
- from einops.layers.torch import Rearrange
32
+ from einops import rearrange, repeat, einsum, reduce, pack
33
+ from einops.layers.torch import Rearrange, Reduce
27
34
 
28
35
  from rotary_embedding_torch import RotaryEmbedding
29
36
 
@@ -31,11 +38,28 @@ from hl_gauss_pytorch import HLGaussLoss
31
38
 
32
39
  from assoc_scan import AssocScan
33
40
 
41
+ from x_mlps_pytorch import MLP
42
+
43
+ from x_evolution import EvoStrategy
44
+
45
+ from discrete_continuous_embed_readout import EmbedAndReadout, Embed, Readout
46
+
47
+ from hyper_connections import mc_get_init_and_expand_reduce_stream_functions
48
+
49
+ from memmap_replay_buffer import ReplayBuffer, ReplayDataset
50
+
34
51
  # constants
35
52
 
36
53
  LinearNoBias = partial(Linear, bias = False)
37
54
 
38
- Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
55
+ TransformerMemory = namedtuple('TransformerMemory', (
56
+ 'total_tokens',
57
+ 'kv_cache',
58
+ 'gru_cache',
59
+ 'mem_mlp_cache',
60
+ 'mem_mlp_hidden_states',
61
+ 'memory_segments'
62
+ ))
39
63
 
40
64
  # helper functions
41
65
 
@@ -45,12 +69,55 @@ def exists(v):
45
69
  def default(v, d):
46
70
  return v if exists(v) else d
47
71
 
72
+ def always(val):
73
+ def inner(*args, **kwargs):
74
+ return val
75
+
76
+ return inner
77
+
78
+ def identity(t, *args, **kwargs):
79
+ return t
80
+
81
+ def pick(data, keys):
82
+ return tuple(data[k] for k in keys)
83
+
48
84
  def first(arr):
49
85
  return arr[0]
50
86
 
87
+ def xnor(x, y):
88
+ return not (x ^ y)
89
+
51
90
  def divisible_by(num, den):
52
91
  return (num % den) == 0
53
92
 
93
+ def get_param_names(fn):
94
+ parameters = signature(fn).parameters
95
+ return list(parameters.keys())
96
+
97
+ def check_has_param_attr(
98
+ param_name,
99
+ param_attr,
100
+ default_value = None
101
+ ):
102
+ def decorator(fn):
103
+ sig = signature(fn)
104
+
105
+ @wraps(fn)
106
+ def inner(*args, **kwargs):
107
+
108
+ bound_args = sig.bind(*args, **kwargs).arguments
109
+
110
+ if not (
111
+ param_name in bound_args and
112
+ hasattr(bound_args[param_name], param_attr)
113
+ ):
114
+ return default_value
115
+
116
+ return fn(*args, **kwargs)
117
+
118
+ return inner
119
+ return decorator
120
+
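For orientation, a minimal sketch of what the `check_has_param_attr` guard does (the penalty function and attribute names here are illustrative, not from the package): the wrapped function only runs when the bound argument actually carries the required attribute; otherwise the decorator short-circuits and returns `default_value`.

    from types import SimpleNamespace

    @check_has_param_attr('state', 'v_z', default_value = 0.)
    def penalty(state):
        return -(state.v_z ** 2)

    penalty(state = SimpleNamespace(v_z = 2.))   # -4.0
    penalty(state = SimpleNamespace(w_z = 1.))   # 0. - no `v_z`, so it short-circuits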
54
121
  # tensor helpers
55
122
 
56
123
  def log(t, eps = 1e-20):
@@ -62,6 +129,11 @@ def is_empty(t):
62
129
  def tree_map_tensor(x, fn):
63
130
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
64
131
 
132
+ def lens_to_mask(lens, max_len):
133
+ device = lens.device
134
+ seq = arange(max_len, device = device)
135
+ return einx.less('j, i -> i j', seq, lens)
136
+
65
137
  def pad_at_dim(
66
138
  t,
67
139
  pad: tuple[int, int],
@@ -75,12 +147,167 @@ def pad_at_dim(
75
147
  zeros = ((0, 0) * dims_from_right)
76
148
  return F.pad(t, (*zeros, *pad), value = value)
77
149
 
78
- def normalize(t, eps = 1e-5):
79
- return (t - t.mean()) / t.std().clamp_min(eps)
150
+ def safe_cat(t, next_t, dim = -1):
151
+ if not exists(t):
152
+ return next_t
153
+
154
+ return cat((t, next_t), dim = dim)
155
+
156
+ def normalize(t, mask = None, eps = 1e-5):
157
+ if exists(mask):
158
+ assert mask.any()
159
+
160
+ t_for_stats = t[mask] if exists(mask) else t
161
+ var, mean = torch.var_mean(t_for_stats)
162
+
163
+ return (t - mean) / var.sqrt().clamp_min(eps)
164
+
165
+ def tensor_to_dict(
166
+ t: Tensor,
167
+ config: tuple[tuple[str, int] | str, ...],
168
+ dim = -1,
169
+ return_dottable = True
170
+ ):
171
+ config = tuple((c, 1) if isinstance(c, str) else c for c in config)
172
+
173
+ names, sizes = zip(*config)
174
+ assert sum(sizes) == t.shape[dim]
175
+
176
+ t = t.split(sizes, dim = dim)
177
+ tensor_dict = dict(zip(names, t))
178
+
179
+ if not return_dottable:
180
+ return tensor_dict
181
+
182
+ return SimpleNamespace(**tensor_dict)
183
+
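A hedged sketch of how the new `tensor_to_dict` helper splits a flat feature vector into named, dot-accessible views (the field names and sizes below are illustrative):

    import torch

    feats = torch.randn(2, 10, 7)     # (batch, time, features)

    state = tensor_to_dict(feats, (('v_xy', 2), 'v_z', ('w_xy', 2), 'w_z', 'x_z'))

    state.v_xy.shape                  # (2, 10, 2)
    state.x_z.shape                   # (2, 10, 1) - bare strings default to size 1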
184
+ # dataset related
185
+
186
+ class RemappedReplayDataset(Dataset):
187
+ def __init__(
188
+ self,
189
+ dataset: ReplayDataset,
190
+ episode_mapping: Tensor | list[list[int]],
191
+ shuffle_episodes = False,
192
+ num_trials_select = None
193
+ ):
194
+ assert len(dataset) > 0
195
+ self.dataset = dataset
196
+
197
+ if is_tensor(episode_mapping):
198
+ assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
199
+ episode_mapping = episode_mapping.tolist()
200
+
201
+ self.episode_mapping = episode_mapping
202
+ self.shuffle_episodes = shuffle_episodes
203
+
204
+ assert not (exists(num_trials_select) and num_trials_select <= 0)
205
+ self.sub_select_trials = exists(num_trials_select)
206
+ self.num_trials_select = num_trials_select
207
+
208
+ def __len__(self):
209
+ return len(self.episode_mapping)
210
+
211
+ def __getitem__(self, idx):
212
+
213
+ episode_indices = self.episode_mapping[idx]
214
+
215
+ episode_indices = tensor(episode_indices)
216
+ episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
217
+
218
+ assert not is_empty(episode_indices)
219
+
220
+ # shuffle the episode indices if either `shuffle_episodes` is turned on or `num_trials_select` is passed in (for sub-selecting episodes from a set)
221
+
222
+ if (
223
+ episode_indices.numel() > 1 and
224
+ (self.shuffle_episodes or self.sub_select_trials)
225
+ ):
226
+ num_episodes = len(episode_indices)
227
+ episode_indices = episode_indices[torch.randperm(num_episodes)]
228
+
229
+ # crop out the episodes
230
+
231
+ if self.sub_select_trials:
232
+ episode_indices = episode_indices[:self.num_trials_select]
233
+
234
+ # now select out the episode data and merge along time
235
+
236
+ episode_data = [self.dataset[i] for i in episode_indices.tolist()]
237
+
238
+ episode_lens = stack([data.pop('_lens') for data in episode_data])
239
+
240
+ keys = first(episode_data).keys()
241
+
242
+ values = [list(data.values()) for data in episode_data]
243
+
244
+ values = [cat(field_values) for field_values in zip(*values)] # concat across time
245
+
246
+ multi_episode_data = dict(zip(keys, values))
247
+
248
+ multi_episode_data['_lens'] = episode_lens.sum()
249
+
250
+ multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
251
+
252
+ return multi_episode_data
253
+
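A hedged sketch of how `RemappedReplayDataset` could be used to merge stored episodes into multi-episode trials (the grouping below is illustrative; `replay_dataset` is assumed to be an existing `ReplayDataset`):

    # group 8 stored episodes into 2 trials of 4, shuffle within each trial,
    # and sub-select 2 episodes per sample
    episode_mapping = [[0, 1, 2, 3], [4, 5, 6, 7]]

    trials = RemappedReplayDataset(
        replay_dataset,
        episode_mapping,
        shuffle_episodes = True,
        num_trials_select = 2
    )

    sample = trials[0]      # selected episodes concatenated along time
    sample['_lens']         # summed length of the selected episodes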
254
+ # reward functions - A.2
255
+
256
+ @check_has_param_attr('state', 'v_xy')
257
+ @check_has_param_attr('command', 'v_xy')
258
+ def reward_linear_velocity_command_tracking(
259
+ state,
260
+ command,
261
+ s1 = 1.
262
+ ):
263
+ error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
264
+ return torch.exp(-error / s1)
265
+
266
+ @check_has_param_attr('state', 'w_z')
267
+ @check_has_param_attr('command', 'w_z')
268
+ def reward_angular_velocity_command_tracking(
269
+ state,
270
+ command,
271
+ s2 = 1.
272
+ ):
273
+ error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
274
+ return torch.exp(-error / s2)
275
+
276
+ @check_has_param_attr('state', 'v_z')
277
+ def reward_base_linear_velocity_penalty(
278
+ state
279
+ ):
280
+ return -state.v_z.norm(dim = -1).pow(2)
281
+
282
+ @check_has_param_attr('state', 'w_xy')
283
+ def reward_base_angular_velocity_penalty(
284
+ state
285
+ ):
286
+ return -state.w_xy.norm(dim = -1).pow(2)
287
+
288
+ @check_has_param_attr('state', 'x_z')
289
+ def reward_base_height_penalty(
290
+ state,
291
+ x_z_nominal = 0.27
292
+ ):
293
+ return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
294
+
295
+ @check_has_param_attr('state', 'joint_q')
296
+ def reward_joint_acceleration_penalty(
297
+ state
298
+ ):
299
+ return -state.joint_q.norm(dim = -1).pow(2)
300
+
301
+ @check_has_param_attr('state', 'tau')
302
+ def reward_torque_penalty(
303
+ state
304
+ ):
305
+ return -state.tau.norm(dim = -1).pow(2)
80
306
 
81
- def calc_entropy(logits):
82
- prob = logits.softmax(dim = -1)
83
- return -(prob * log(prob)).sum(dim = -1)
307
+ def reward_alive(
308
+ state
309
+ ):
310
+ return 1.
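A hedged sketch of how these A.2 shaping terms might be wired into `Locoformer` (the particular selection and the `x_z_nominal` override below are illustrative):

    from functools import partial

    reward_shaping_fns = [
        reward_linear_velocity_command_tracking,
        reward_angular_velocity_command_tracking,
        partial(reward_base_height_penalty, x_z_nominal = 0.3),
        reward_alive
    ]

    # pass as `reward_shaping_fns = reward_shaping_fns` when constructing Locoformer,
    # or as a list of such lists (one per environment) together with `env_index` at step time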
84
311
 
85
312
  # generalized advantage estimate
86
313
 
@@ -185,283 +412,87 @@ def create_sliding_mask(
185
412
  create_kwargs = dict(device = device) if exists(device) else dict()
186
413
  return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
187
414
 
188
- # data
189
-
190
- def collate_var_time(data):
191
-
192
- datum = first(data)
193
- keys = datum.keys()
194
-
195
- all_tensors = zip(*[datum.values() for datum in data])
196
-
197
- collated_values = []
198
-
199
- for key, tensors in zip(keys, all_tensors):
200
-
201
- # the episode lens have zero dimension - think of a cleaner way to handle this later
202
-
203
- if key != '_lens':
204
-
205
- times = [t.shape[0] for t in tensors]
206
- max_time = max(times)
207
- tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
208
-
209
- collated_values.append(stack(tensors))
210
-
211
- return dict(zip(keys, collated_values))
212
-
213
- class ReplayDataset(Dataset):
214
- def __init__(
215
- self,
216
- folder: str | Path,
217
- fields: tuple[str, ...] | None = None
218
- ):
219
- if isinstance(folder, str):
220
- folder = Path(folder)
221
-
222
- episode_lens = folder / 'episode_lens.npy'
223
- self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
224
-
225
- # get indices of non-zero lengthed episodes
226
-
227
- nonzero_episodes = self.episode_lens > 0
228
- self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
229
-
230
- # get all data files
231
-
232
- filepaths = [*folder.glob('*.data.npy')]
233
- assert len(filepaths) > 0
234
-
235
- fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
236
-
237
- fieldnames_from_files = set(fieldname_to_filepath.keys())
238
-
239
- fields = default(fields, fieldnames_from_files)
240
-
241
- self.memmaps = dict()
242
-
243
- for field in fields:
244
- assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
245
-
246
- path = fieldname_to_filepath[field]
247
-
248
- self.memmaps[field] = open_memmap(str(path), mode = 'r')
415
+ # normalization + conditioning (needed for the commands to the robot)
249
416
 
250
- def __len__(self):
251
- return len(self.indices)
252
-
253
- def __getitem__(self, idx):
254
- episode_index = self.indices[idx]
255
-
256
- episode_len = self.episode_lens[episode_index]
257
-
258
- data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
259
-
260
- data['_lens'] = tensor(episode_len)
261
-
262
- return data
263
-
264
- class RemappedReplayDataset(Dataset):
417
+ class MaybeAdaRMSNormWrapper(Module):
265
418
  def __init__(
266
419
  self,
267
- dataset: ReplayDataset,
268
- episode_mapping: Tensor | list[list[int]],
269
- shuffle_episodes = False
420
+ fn: Module,
421
+ dim,
422
+ dim_cond = None
270
423
  ):
271
- assert len(dataset) > 0
272
- self.dataset = dataset
273
-
274
- if is_tensor(episode_mapping):
275
- assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
276
- episode_mapping = episode_mapping.tolist()
277
-
278
- self.episode_mapping = episode_mapping
279
- self.shuffle_episodes = shuffle_episodes
280
-
281
- def __len__(self):
282
- return len(self.episode_mapping)
283
-
284
- def __getitem__(self, idx):
285
-
286
- episode_indices = self.episode_mapping[idx]
287
-
288
- episode_indices = tensor(episode_indices)
289
- episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
290
-
291
- assert not is_empty(episode_indices)
292
-
293
- if self.shuffle_episodes and episode_indices.numel() > 1:
294
- num_episodes = len(episode_indices)
295
- episode_indices = episode_indices[torch.randperm(num_episodes)]
296
-
297
- episode_data = [self.dataset[i] for i in episode_indices.tolist()]
298
-
299
- episode_lens = stack([data.pop('_lens') for data in episode_data])
300
-
301
- keys = first(episode_data).keys()
302
-
303
- values = [list(data.values()) for data in episode_data]
304
-
305
- values = [cat(field_values) for field_values in zip(*values)] # concat across time
306
-
307
- multi_episode_data = dict(zip(keys, values))
424
+ super().__init__()
425
+ condition = exists(dim_cond)
308
426
 
309
- multi_episode_data['_lens'] = episode_lens.sum()
427
+ self.fn = fn
428
+ self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)
310
429
 
311
- multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
430
+ self.accept_condition = condition
312
431
 
313
- return multi_episode_data
432
+ if condition:
433
+ self.to_gamma = LinearNoBias(dim_cond, dim)
434
+ self.to_ada_norm_zero = nn.Linear(dim_cond, dim)
314
435
 
315
- class ReplayBuffer:
436
+ nn.init.zeros_(self.to_gamma.weight)
437
+ nn.init.zeros_(self.to_ada_norm_zero.weight)
438
+ nn.init.constant_(self.to_ada_norm_zero.bias, -5.)
316
439
 
317
- @beartype
318
- def __init__(
440
+ def forward(
319
441
  self,
320
- folder: str | Path,
321
- max_episodes: int,
322
- max_timesteps: int,
323
- fields: dict[
324
- str,
325
- str | tuple[str, int | tuple[int, ...]]
326
- ]
442
+ x,
443
+ *args,
444
+ cond = None,
445
+ cond_mask = None,
446
+ **kwargs
327
447
  ):
328
448
 
329
- # folder for data
330
-
331
- if not isinstance(folder, Path):
332
- folder = Path(folder)
333
- folder.mkdir(exist_ok = True)
334
-
335
- self.folder = folder
336
- assert folder.is_dir()
337
-
338
- # keeping track of episode length
339
-
340
- episode_lens = folder / 'episode_lens.npy'
341
-
342
- self.episode_index = 0
343
- self.timestep_index = 0
344
-
345
- self.max_episodes = max_episodes
346
- self.max_timesteps= max_timesteps
449
+ need_cond = self.accept_condition
450
+ has_input_cond = need_cond and exists(cond)
347
451
 
348
- self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
452
+ if exists(cond):
453
+ assert self.accept_condition
349
454
 
350
- # create the memmap for individual data tracks
455
+ prenormed = self.norm(x)
351
456
 
352
- self.shapes = dict()
353
- self.dtypes = dict()
354
- self.memmaps = dict()
355
- self.fieldnames = set(fields.keys())
457
+ if has_input_cond:
458
+ if cond.ndim == 2:
459
+ cond = rearrange(cond, 'b d -> b 1 d')
356
460
 
357
- for field_name, field_info in fields.items():
461
+ cond_scale = self.to_gamma(cond)
358
462
 
359
- # some flexibility
463
+ conditioned = prenormed * cond_scale
360
464
 
361
- field_info = (field_info, ()) if isinstance(field_info, str) else field_info
465
+ # handle a condition mask
362
466
 
363
- dtype_str, shape = field_info
364
- assert dtype_str in {'int', 'float', 'bool'}
365
-
366
- dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
367
-
368
- # memmap file
369
-
370
- filepath = folder / f'{field_name}.data.npy'
371
- memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
372
-
373
- self.memmaps[field_name] = memmap
374
- self.shapes[field_name] = shape
375
- self.dtypes[field_name] = dtype
376
-
377
- self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
378
-
379
- def __len__(self):
380
- return (self.episode_lens > 0).sum().item()
381
-
382
- def reset_(self):
383
- self.episode_lens[:] = 0
384
- self.episode_index = 0
385
- self.timestep_index = 0
386
-
387
- def advance_episode(self):
388
- self.episode_index = (self.episode_index + 1) % self.max_episodes
389
- self.timestep_index = 0
390
-
391
- def flush(self):
392
- self.episode_lens[self.episode_index] = self.timestep_index
393
-
394
- for memmap in self.memmaps.values():
395
- memmap.flush()
396
-
397
- self.episode_lens.flush()
398
-
399
- @contextmanager
400
- def one_episode(self):
401
-
402
- yield
403
-
404
- self.flush()
405
- self.advance_episode()
406
-
407
- @beartype
408
- def store_datapoint(
409
- self,
410
- episode_index: int,
411
- timestep_index: int,
412
- name: str,
413
- datapoint: Tensor | ndarray
414
- ):
415
- assert 0 <= episode_index < self.max_episodes
416
- assert 0 <= timestep_index < self.max_timesteps
417
-
418
- if is_tensor(datapoint):
419
- datapoint = datapoint.detach().cpu().numpy()
420
-
421
- assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
422
-
423
- assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
424
-
425
- self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
426
-
427
- def store(
428
- self,
429
- **data
430
- ):
431
- assert is_bearable(data, dict[str, Tensor | ndarray])
432
-
433
- assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
467
+ if exists(cond_mask):
468
+ prenormed = einx.where('b n, b n d, b n d', cond_mask, conditioned, prenormed)
469
+ else:
470
+ prenormed = conditioned
434
471
 
435
- for name, datapoint in data.items():
472
+ # the main block, either attention or feedforward or whatever
436
473
 
437
- self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
474
+ all_fn_out = self.fn(prenormed, *args, **kwargs)
438
475
 
439
- self.timestep_index += 1
476
+ if not has_input_cond:
477
+ return all_fn_out
440
478
 
441
- return self.memory_namedtuple(**data)
479
+ # function may return multiple args
442
480
 
443
- def dataset(
444
- self,
445
- episode_mapping: Tensor | list[list[int]] | None = None,
446
- ) -> Dataset:
447
- self.flush()
481
+ (out, *rest), tree_spec = tree_flatten(all_fn_out)
448
482
 
449
- dataset = ReplayDataset(self.folder)
483
+ scale_out = self.to_ada_norm_zero(cond).sigmoid()
450
484
 
451
- if not exists(episode_mapping):
452
- return dataset
485
+ if exists(cond_mask):
486
+ is_cond = rearrange(cond_mask, '... -> ... 1')
487
+ out = torch.where(is_cond, out * scale_out, out)
488
+ else:
489
+ out = out * scale_out
453
490
 
454
- return RemappedReplayDataset(dataset, episode_mapping)
491
+ # restore
455
492
 
456
- def dataloader(
457
- self,
458
- batch_size,
459
- episode_mapping: Tensor | list[list[int]] | None = None,
460
- **kwargs
461
- ) -> DataLoader:
462
- self.flush()
493
+ all_fn_out = tree_unflatten((out, *rest), tree_spec)
463
494
 
464
- return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
495
+ return all_fn_out
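A hedged sketch of the conditioning wrapper in isolation (wrapping a plain `nn.Linear` here purely for illustration): the condition vector is projected to a per-channel scale on the pre-norm, and the block output is gated ada-norm-zero style with a sigmoid.

    import torch
    from torch import nn

    block = MaybeAdaRMSNormWrapper(nn.Linear(64, 64), dim = 64, dim_cond = 128)

    x = torch.randn(2, 16, 64)       # (batch, seq, dim)
    cond = torch.randn(2, 128)       # broadcast across the sequence internally

    out = block(x, cond = cond)      # same shape as x
    out = block(x)                   # also valid - behaves as a plain pre-norm wrapper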
465
496
 
466
497
  # transformer-xl with ppo
467
498
 
@@ -472,15 +503,13 @@ class Attention(Module):
472
503
  window_size,
473
504
  dim_head = 64,
474
505
  heads = 8,
475
- pre_rmsnorm = True,
476
506
  fixed_window_size = False,
477
- accept_value_residual = False
507
+ accept_value_residual = False,
508
+ max_mem_segments = 1
478
509
  ):
479
510
  super().__init__()
480
511
  self.scale = dim_head ** -0.5
481
512
 
482
- self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
483
-
484
513
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
485
514
  self.merge_heads = Rearrange('b h n d -> b n (h d)')
486
515
 
@@ -512,20 +541,22 @@ class Attention(Module):
512
541
 
513
542
  self.fixed_window_size = fixed_window_size
514
543
  self.window_size = window_size
544
+ self.max_mem_segments = max_mem_segments
545
+
546
+ self.register_buffer('causal_mask', None, persistent = False)
515
547
 
516
548
  def forward(
517
549
  self,
518
550
  tokens,
519
551
  value_residual = None,
520
552
  kv_cache = None,
553
+ past_segments = None,
521
554
  return_kv_cache = False,
522
555
  ):
523
556
  seq_len = tokens.shape[-2]
524
557
 
525
558
  device = tokens.device
526
559
 
527
- tokens = self.norm(tokens)
528
-
529
560
  q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
530
561
 
531
562
  q, k, v = map(self.split_heads, (q, k, v))
@@ -544,6 +575,11 @@ class Attention(Module):
544
575
  k = cat((ck, k), dim = -2)
545
576
  v = cat((cv, v), dim = -2)
546
577
 
578
+ if exists(past_segments):
579
+ pk, pv = past_segments
580
+ k = cat((pk, k), dim = -2)
581
+ v = cat((pv, v), dim = -2)
582
+
547
583
  if return_kv_cache:
548
584
  next_kv_cache = stack((k, v))
549
585
 
@@ -557,9 +593,12 @@ class Attention(Module):
557
593
  i_seq = arange(i, device = device)
558
594
  j_seq = arange(j, device = device) - (j - i)
559
595
  dist = einx.subtract('i, j -> i j', i_seq, j_seq)
560
- causal_mask = (dist < 0) | (dist > self.window_size)
596
+ causal_mask = (dist < 0) | (dist > (self.max_mem_segments * self.window_size))
561
597
  else:
562
- causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
598
+ if not exists(self.causal_mask) or self.causal_mask.shape != (i, j):
599
+ self.causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
600
+
601
+ causal_mask = self.causal_mask
563
602
 
564
603
  sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
565
604
 
@@ -606,6 +645,7 @@ class FeedForward(Module):
606
645
  return self.proj_out(x)
607
646
 
608
647
  class TransformerXL(Module):
648
+ @beartype
609
649
  def __init__(
610
650
  self,
611
651
  dim,
@@ -614,27 +654,83 @@ class TransformerXL(Module):
614
654
  dim_head = 64,
615
655
  heads = 8,
616
656
  expansion_factor = 4.,
657
+ dim_cond = None,
617
658
  final_norm = True,
618
659
  fixed_window_size = False,
660
+ gru_layers = False,
661
+ long_term_mem_layers: tuple[int, ...] = (),
662
+ mem_kwargs: dict = dict(),
663
+ num_residual_streams = 1,
664
+ max_mem_segments = 1
619
665
  ):
620
666
  super().__init__()
667
+ self.dim = dim
668
+
669
+ # memory
670
+
671
+ long_term_mem_layers = set(long_term_mem_layers)
672
+
673
+ assert all([1 <= l <= depth for l in long_term_mem_layers])
674
+
675
+ self.long_term_mem_layers = long_term_mem_layers
676
+ self.num_mem_mlps = len(long_term_mem_layers)
677
+ self.has_mem = self.num_mem_mlps > 0
678
+ self.max_mem_segments = max_mem_segments
679
+
680
+ # hyper connections
681
+
682
+ init_hyper_conn, self.expand_streams, self.reduce_streams = mc_get_init_and_expand_reduce_stream_functions(num_residual_streams)
683
+
684
+ # condition
685
+
686
+ condition = exists(dim_cond)
687
+
688
+ self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
689
+
690
+ norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
691
+
692
+ # layers
621
693
 
622
694
  layers = ModuleList([])
623
695
 
624
696
  for i in range(depth):
625
- is_first = i == 0
697
+ layer = i + 1
698
+ is_first = layer == 1
699
+ has_mem = layer in long_term_mem_layers
700
+
701
+ gru = norm_fn(nn.GRU(dim, dim, batch_first = True)) if gru_layers else None
702
+
703
+ mem = MemoryMLP(dim, **mem_kwargs) if has_mem else None
704
+
705
+ attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first, max_mem_segments = max_mem_segments))
706
+
707
+ ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))
626
708
 
627
- attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
709
+ # maybe hyper connections
628
710
 
629
- ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
711
+ mem_store_hc = init_hyper_conn(dim = dim, add_branch_out_to_residual = False)
712
+
713
+ if num_residual_streams > 1:
714
+
715
+ attn = init_hyper_conn(dim = dim, branch = attn)
716
+
717
+ ff = init_hyper_conn(dim = dim, branch = ff)
718
+
719
+ if gru_layers:
720
+ gru = init_hyper_conn(dim = dim, branch = gru)
721
+
722
+ if has_mem:
723
+ mem = init_hyper_conn(dim = dim, branch = mem, forward_method_names = ('store',))
630
724
 
631
725
  layers.append(ModuleList([
632
- attn, ff
726
+ gru, mem, mem_store_hc, attn, ff
633
727
  ]))
634
728
 
635
729
  self.layers = layers
636
730
  self.norm = RMSNorm(dim) if final_norm else Identity()
637
731
 
732
+ self.gru_layers = gru_layers
733
+
638
734
  # fixed window size
639
735
 
640
736
  self.fixed_window_size = fixed_window_size
@@ -643,72 +739,439 @@ class TransformerXL(Module):
643
739
  def forward(
644
740
  self,
645
741
  x,
646
- cache = None,
647
- return_kv_cache = False
742
+ cache: TransformerMemory | None = None,
743
+ return_kv_cache = False,
744
+ condition: Tensor | None = None,
745
+ cond_mask: Tensor | None = None
648
746
  ):
747
+ curr_token_seq_len = x.shape[-2]
649
748
 
650
- cache = default(cache, (None,) * len(self.layers))
749
+ # cache and residuals
750
+
751
+ num_layers = len(self.layers)
752
+
753
+ # extract variables from cache
754
+
755
+ is_first_window = True
756
+ total_tokens = 0
757
+ kv_cache = gru_cache = mem_mlp_cache = mem_mlp_hidden_states = memory_segments = None
758
+
759
+ if exists(cache):
760
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
761
+ is_first_window = total_tokens < self.window_size
762
+ else:
763
+ memory_segments = deque(maxlen = self.max_mem_segments)
764
+
765
+ # handle memory segments
766
+
767
+ past_segments = None
768
+ if len(memory_segments) > 0:
769
+ past_segments = stack(list(memory_segments), dim = 0)
770
+ past_segments = rearrange(past_segments, 'l depth kv b h n d -> depth kv b h (l n) d')
771
+
772
+ kv_cache = default(kv_cache, (None,) * num_layers)
773
+ gru_cache = default(gru_cache, (None,) * num_layers)
774
+ mem_mlp_cache = default(mem_mlp_cache, (None,) * num_layers)
775
+ mem_mlp_hidden_states = default(mem_mlp_hidden_states, (None,) * num_layers)
776
+
777
+ # prepare next cache
778
+
779
+ next_kv_caches = []
780
+ next_gru_hiddens = [] if self.gru_layers else None
781
+ next_mem_mlp_cache = [] if self.has_mem else None
782
+ next_mem_mlp_hidden_states = [] if self.has_mem else None
783
+ next_total_tokens = total_tokens + curr_token_seq_len
784
+
785
+ is_window_boundary = divisible_by(next_total_tokens, self.window_size)
651
786
 
652
- next_kv_caches = []
653
787
  value_residual = None
654
788
 
655
- for (attn, ff), kv_cache in zip(self.layers, cache):
789
+ # handle condition
790
+
791
+ cond_tokens = None
792
+
793
+ if exists(condition):
794
+ assert exists(self.to_cond_tokens)
795
+ cond_tokens = self.to_cond_tokens(condition)
796
+
797
+ cond_kwargs = dict(cond = cond_tokens, cond_mask = cond_mask)
798
+
799
+ # hc expand
800
+
801
+ x = self.expand_streams(x)
802
+
803
+ # layers
804
+
805
+ for layer_index, ((maybe_gru, maybe_mem, maybe_mem_store_hc, attn, ff), layer_gru_cache, layer_mem_mlp, layer_kv_cache, layer_hidden_states) in enumerate(zip(self.layers, gru_cache, mem_mlp_cache, kv_cache, mem_mlp_hidden_states)):
806
+
807
+ # handle maybe rnn
656
808
 
657
- attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
809
+ if exists(maybe_gru):
810
+ x, gru_hiddens = maybe_gru(x, layer_gru_cache, **cond_kwargs)
658
811
 
659
- x = attn_out + x
660
- x = ff(x) + x
812
+ next_gru_hiddens.append(gru_hiddens)
813
+
814
+ # maybe handle retrieving
815
+
816
+ is_mem_layer = exists(maybe_mem)
817
+
818
+ if (
819
+ not is_first_window and
820
+ is_mem_layer
821
+ ):
822
+ x = maybe_mem(x, layer_mem_mlp)
823
+
824
+ # attention
825
+
826
+ layer_past_segments = None
827
+ if exists(past_segments):
828
+ layer_past_segments = past_segments[layer_index]
829
+
830
+ x, (next_kv_cache, values) = attn(x, **cond_kwargs, value_residual = value_residual, kv_cache = layer_kv_cache, past_segments = layer_past_segments, return_kv_cache = True)
831
+
832
+ # handle storing of memory
833
+
834
+ if self.has_mem:
835
+ next_mem_mlp = layer_mem_mlp
836
+ next_layer_hidden_states = layer_hidden_states
837
+
838
+ if is_mem_layer:
839
+ # accumulate hidden states
840
+ next_layer_hidden_states = safe_cat(layer_hidden_states, x, dim = -2)
841
+
842
+ if is_window_boundary:
843
+ mem_store_input, _ = maybe_mem_store_hc(next_layer_hidden_states)
844
+
845
+ next_mem_mlp = maybe_mem.store(mem_store_input, layer_mem_mlp)
846
+ next_layer_hidden_states = None
847
+
848
+ next_mem_mlp_cache.append(next_mem_mlp)
849
+ next_mem_mlp_hidden_states.append(next_layer_hidden_states)
850
+
851
+ # feedforward
852
+
853
+ x = ff(x, **cond_kwargs)
661
854
 
662
855
  next_kv_caches.append(next_kv_cache)
663
856
  value_residual = default(value_residual, values)
664
857
 
665
- embed = self.norm(x)
858
+ # hc reduce
666
859
 
667
- if not return_kv_cache:
668
- return embed
860
+ x = self.reduce_streams(x)
861
+
862
+ # norm
863
+
864
+ embed = self.norm(x)
669
865
 
670
866
  next_kv_cache = stack(next_kv_caches)
671
867
 
672
- next_kv_cache = next_kv_cache[..., -self.window_size:, :]
868
+ if exists(next_gru_hiddens):
869
+ next_gru_hiddens = stack(next_gru_hiddens)
870
+
871
+ next_cache = TransformerMemory(next_total_tokens, next_kv_cache, next_gru_hiddens, next_mem_mlp_cache, next_mem_mlp_hidden_states, memory_segments)
872
+
873
+ return embed, next_cache
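A hedged sketch of rolling the updated `TransformerXL` across consecutive windows while threading the new `TransformerMemory` cache (sizes are illustrative, and the default single-stream configuration without GRU or memory-MLP layers is assumed):

    import torch

    xl = TransformerXL(dim = 64, depth = 2, window_size = 32)

    cache = None
    for _ in range(4):
        tokens = torch.randn(1, 32, 64)             # one window of embeddings
        embed, cache = xl(tokens, cache = cache)    # cache carries total tokens, kv, gru, memory mlp, segments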
874
+
875
+ # simple 2 layer memory mlp
876
+ # following ttt/titans
877
+
878
+ from torch.func import functional_call, grad, vmap
879
+
880
+ class MemoryMLP(Module):
881
+ def __init__(
882
+ self,
883
+ dim,
884
+ expansion_factor = 4.
885
+ ):
886
+ super().__init__()
887
+
888
+ dim_hidden = int(dim * expansion_factor)
889
+
890
+ self.norm = nn.RMSNorm(dim)
891
+
892
+ # queries, keys, values
893
+
894
+ self.to_queries = Linear(dim, dim, bias = False)
895
+
896
+ self.to_key_values = nn.Sequential(
897
+ Linear(dim, dim * 2, bias = False),
898
+ nn.SiLU()
899
+ )
900
+
901
+ # memory mlp
902
+
903
+ self.mlp = MLP(dim, dim_hidden, dim, activation = nn.SiLU())
904
+
905
+ # initial params
906
+
907
+ self.init_mlp_params = dict(self.mlp.named_parameters())
908
+
909
+ # grad for storing
910
+
911
+ def retrieve_fn(params, queries: Tensor):
912
+ return functional_call(self.mlp, params, queries)
913
+
914
+ def loss_fn(params, inputs: tuple[Tensor, Tensor, Tensor]):
915
+ keys, values, learning_rate = inputs
916
+ pred = functional_call(self.mlp, params, keys)
917
+ loss = F.mse_loss(pred, values, reduction = 'none')
918
+ loss = loss * learning_rate
919
+ return loss.mean()
920
+
921
+ self.grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0, 0)))
922
+
923
+ self.retrieve_fn = vmap(retrieve_fn, in_dims = (0, 0))
924
+
925
+ # forgetting
926
+
927
+ self.to_forget_gate = nn.Sequential(
928
+ Reduce('b n d -> b d', 'mean'),
929
+ nn.Linear(dim, 1, bias = False),
930
+ Rearrange('b 1 -> b'),
931
+ nn.Sigmoid()
932
+ )
933
+
934
+ # loss weight / learning rate
935
+
936
+ self.to_loss_weight = nn.Linear(dim, 1, bias = False)
937
+
938
+ def get_init_mlp_params(
939
+ self,
940
+ batch_size
941
+ ):
942
+ return {name: repeat(params, '... -> b ...', b = batch_size) for name, params in self.init_mlp_params.items()}
943
+
944
+ def store(
945
+ self,
946
+ tokens, # (b n d)
947
+ memories: dict[str, Tensor] | None = None
948
+ ):
949
+
950
+ batch_size = tokens.shape[0]
951
+
952
+ if not exists(memories):
953
+ memories = self.get_init_mlp_params(batch_size)
954
+
955
+ tokens = self.norm(tokens)
956
+
957
+ keys, values = self.to_key_values(tokens).chunk(2, dim = -1)
958
+
959
+ loss_weight = self.to_loss_weight(tokens)
960
+
961
+ grad = self.grad_fn(memories, (keys, values, loss_weight))
962
+
963
+ # prepare forget
964
+
965
+ forget = self.to_forget_gate(tokens)
966
+
967
+ # update memories
968
+
969
+ next_memories = dict()
970
+
971
+ for param_name, past_memory in memories.items():
972
+ change = grad[param_name]
973
+
974
+ past_memory = einx.multiply('b, b ...', forget, past_memory)
975
+
976
+ next_memories[param_name] = past_memory - change
977
+
978
+ return next_memories
979
+
980
+ def forward(
981
+ self,
982
+ tokens, # (b n d)
983
+ memories: dict[str, Tensor] | None = None
984
+ ):
985
+ batch_size = tokens.shape[0]
986
+
987
+ if not exists(memories):
988
+ memories = self.get_init_mlp_params(batch_size)
989
+
990
+ tokens = self.norm(tokens)
991
+
992
+ queries = self.to_queries(tokens)
993
+
994
+ retrieved = self.retrieve_fn(memories, queries)
995
+
996
+ return retrieved
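A hedged sketch of the fast-weight memory API: `store` takes one gradient step on per-sample MLP weights (modulated by the learned forget gate and loss weights), and the regular forward retrieves against whichever weights are passed in.

    import torch

    mem = MemoryMLP(dim = 64)

    tokens = torch.randn(2, 32, 64)     # (batch, window, dim)

    weights = mem.store(tokens)         # dict of per-sample fast weights, updated by one step
    retrieved = mem(tokens, weights)    # (2, 32, 64), queries read from the updated weights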
997
+
998
+ # state embedder
999
+
1000
+ class StateEmbedder(Module):
1001
+ @beartype
1002
+ def __init__(
1003
+ self,
1004
+ dim,
1005
+ dim_state: tuple[int, ...] | list[int] | int,
1006
+ num_internal_states: int | None = None,
1007
+ internal_states_selectors: list[list[int]] | None = None
1008
+ ):
1009
+ super().__init__()
1010
+ dim_hidden = dim * 2
1011
+
1012
+ self.image_to_token = nn.Sequential(
1013
+ Rearrange('b t c h w -> b c t h w'),
1014
+ nn.Conv3d(3, dim_hidden, (1, 7, 7), padding = (0, 3, 3)),
1015
+ nn.ReLU(),
1016
+ nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
1017
+ nn.ReLU(),
1018
+ nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
1019
+ Reduce('b c t h w -> b t c', 'mean'),
1020
+ nn.Linear(dim_hidden, dim)
1021
+ )
1022
+
1023
+ dim_states = (dim_state,) if not isinstance(dim_state, (tuple, list)) else dim_state
1024
+
1025
+ self.dim_states = dim_states
1026
+ self.state_to_token = ModuleList([MLP(dim_state, dim, bias = False) for dim_state in dim_states])
1027
+
1028
+ # internal state embeds for each robot
1029
+
1030
+ self.internal_state_embedder = None
1031
+
1032
+ if exists(num_internal_states) and exists(internal_states_selectors):
1033
+ self.internal_state_embedder = Embed(
1034
+ dim,
1035
+ num_continuous = num_internal_states,
1036
+ selectors = internal_states_selectors
1037
+ )
1038
+
1039
+ @property
1040
+ def device(self):
1041
+ return next(self.parameters()).device
1042
+
1043
+ def forward(
1044
+ self,
1045
+ state,
1046
+ state_type,
1047
+ state_id = 0,
1048
+ internal_state = None,
1049
+ internal_state_selector_id: int | None = None
1050
+ ):
1051
+
1052
+ if state_type == 'image':
1053
+ token_embeds = self.image_to_token(state)
1054
+ elif state_type == 'raw':
1055
+ state_to_token = self.state_to_token[state_id]
1056
+ token_embeds = state_to_token(state)
1057
+ else:
1058
+ raise ValueError('invalid state type')
673
1059
 
674
- return embed, next_kv_cache
1060
+ if (
1061
+ exists(internal_state_selector_id) and
1062
+ exists(internal_state) and
1063
+ exists(self.internal_state_embedder)
1064
+ ):
1065
+ internal_state = internal_state.to(self.device)
1066
+
1067
+ internal_state_embed = self.internal_state_embedder(internal_state, selector_index = internal_state_selector_id)
1068
+
1069
+ token_embeds = token_embeds + internal_state_embed
1070
+
1071
+ return token_embeds
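A hedged sketch of the two input paths of `StateEmbedder` (the state sizes and frame resolution are illustrative): raw proprioceptive vectors go through the per-robot MLPs, while image clips go through the small 3D conv stack.

    import torch

    embedder = StateEmbedder(dim = 64, dim_state = (48, 35))    # two robot state sizes

    proprio = torch.randn(1, 10, 48)                            # (batch, time, state_dim)
    tokens = embedder(proprio, state_type = 'raw', state_id = 0)

    frames = torch.randn(1, 10, 3, 64, 64)                      # (batch, time, rgb, h, w)
    tokens = embedder(frames, state_type = 'image')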
675
1072
 
676
1073
  # class
677
1074
 
1075
+ OneRewardShaper = Callable[..., float | Tensor]
1076
+
1077
+ MaybeOneRewardShaper = OneRewardShaper | None
1078
+
1079
+ @beartype
1080
+ def default_parse_env_reset_out(reset_out: tuple):
1081
+ assert len(reset_out) == 2
1082
+ return dict(zip(('state', 'info'), reset_out))
1083
+
1084
+ @beartype
1085
+ def default_parse_env_step_out(step_out: tuple):
1086
+ assert len(step_out) in {4, 5}
1087
+
1088
+ if len(step_out) == 5:
1089
+ data_dict = dict(zip(('state', 'reward', 'terminated', 'truncated', 'info'), step_out))
1090
+ elif len(step_out) == 4:
1091
+ data_dict = dict(zip(('state', 'reward', 'terminated', 'info'), step_out))
1092
+ data_dict['truncated'] = False
1093
+
1094
+ return data_dict
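Both gym-style step signatures normalize to the same dictionary; a small hedged example with a placeholder observation:

    obs = [0., 0., 0.]      # placeholder observation

    parsed = default_parse_env_step_out((obs, 0.5, False, {}))

    parsed['truncated']     # False - filled in automatically for the 4-tuple form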
1095
+
678
1096
  class Locoformer(Module):
679
1097
  def __init__(
680
1098
  self,
681
- embedder: Module,
682
- unembedder: Module,
1099
+ embedder: dict | Module,
1100
+ unembedder: dict | Readout,
683
1101
  transformer: dict | TransformerXL,
1102
+ *,
684
1103
  discount_factor = 0.999,
685
1104
  gae_lam = 0.95,
686
1105
  ppo_eps_clip = 0.2,
687
1106
  ppo_entropy_weight = 0.01,
688
1107
  ppo_value_clip = 0.4,
689
- dim_value_input = None, # needs to be set for value network to be available
1108
+ ppo_soft_constrain_action_max = None,
1109
+ ppo_soft_constrain_action_loss_weight = 0.1,
1110
+ dim_value_input = None, # needs to be set for value network to be available
690
1111
  value_network: Module = nn.Identity(),
1112
+ policy_network: Module = nn.Identity(),
1113
+ state_pred_network: Module | None = None,
1114
+ embed_past_action = False,
1115
+ state_pred_loss_weight = 0.05,
691
1116
  reward_range: tuple[float, float] | None = None,
692
- reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
1117
+ reward_shaping_fns: (
1118
+ MaybeOneRewardShaper |
1119
+ list[MaybeOneRewardShaper] |
1120
+ list[list[MaybeOneRewardShaper]]
1121
+ ) = None,
693
1122
  num_reward_bins = 32,
694
1123
  hl_gauss_loss_kwargs = dict(),
695
1124
  value_loss_weight = 0.5,
696
1125
  calc_gae_kwargs: dict = dict(),
697
- recurrent_kv_cache = True,
698
- use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
1126
+ parse_env_reset_out: Callable | None = None,
1127
+ parse_env_step_out: Callable | None = None,
1128
+ recurrent_cache = True,
1129
+ use_spo = False, # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
1130
+ asymmetric_spo = False, # https://openreview.net/pdf?id=BA6n0nmagi
1131
+ max_mem_segments = 1
699
1132
  ):
700
1133
  super().__init__()
701
1134
 
702
1135
  if isinstance(transformer, dict):
703
- transformer = TransformerXL(**transformer)
1136
+ transformer = TransformerXL(max_mem_segments = max_mem_segments, **transformer)
704
1137
 
705
1138
  self.transformer = transformer
706
1139
 
1140
+ # handle state embedder
1141
+
1142
+ if isinstance(embedder, dict):
1143
+ embedder = StateEmbedder(**embedder)
1144
+
707
1145
  self.embedder = embedder
1146
+
1147
+ # unembed state to actions or ssl predictions
1148
+
1149
+ action_embedder = None
1150
+ if isinstance(unembedder, dict):
1151
+ action_embedder, unembedder = EmbedAndReadout(
1152
+ explicit_single_action_dim_given = True,
1153
+ **unembedder,
1154
+ )
1155
+
708
1156
  self.unembedder = unembedder
709
1157
 
1158
+ # embedding past actions
1159
+
1160
+ self.past_action_embedder = None
1161
+ self.embed_past_action = embed_past_action
1162
+
1163
+ if embed_past_action and exists(action_embedder):
1164
+ self.past_action_embedder = action_embedder
1165
+
1166
+ # attention window related
1167
+
710
1168
  self.fixed_window_size = transformer.fixed_window_size
711
1169
  self.window_size = transformer.window_size
1170
+ self.max_mem_segments = max_mem_segments
1171
+
1172
+ # policy network
1173
+
1174
+ self.policy_network = policy_network
712
1175
 
713
1176
  # determine value network, using HL Gauss Layer
714
1177
 
@@ -731,6 +1194,22 @@ class Locoformer(Module):
731
1194
  **hl_gauss_loss_kwargs
732
1195
  )
733
1196
 
1197
+ # state prediction related
1198
+
1199
+ self.can_pred_state = exists(state_pred_network)
1200
+ self.state_pred_network = state_pred_network
1201
+
1202
+ if exists(state_pred_network):
1203
+ dim_states = self.embedder.dim_states
1204
+ total_dim_states = sum(dim_states)
1205
+
1206
+ selectors = [t.tolist() for t in arange(total_dim_states).split(dim_states)]
1207
+
1208
+ self.state_pred_head = Readout(transformer.dim, num_continuous = total_dim_states, selectors = selectors)
1209
+
1210
+ self.has_state_pred_loss = state_pred_loss_weight > 0.
1211
+ self.state_pred_loss_weight = state_pred_loss_weight
1212
+
734
1213
  # ppo related
735
1214
 
736
1215
  self.discount_factor = discount_factor
@@ -738,6 +1217,9 @@ class Locoformer(Module):
738
1217
  self.ppo_eps_clip = ppo_eps_clip
739
1218
  self.ppo_entropy_weight = ppo_entropy_weight
740
1219
  self.ppo_value_clip = ppo_value_clip
1220
+ self.ppo_soft_constrain_action_max = ppo_soft_constrain_action_max
1221
+ self.ppo_soft_constrain_action_loss_weight = ppo_soft_constrain_action_loss_weight
1222
+
741
1223
  self.value_loss_weight = value_loss_weight
742
1224
 
743
1225
  self.calc_gae_kwargs = calc_gae_kwargs
@@ -746,14 +1228,26 @@ class Locoformer(Module):
746
1228
 
747
1229
  self.use_spo = use_spo
748
1230
 
749
- # maybe recurrent kv cache (todo: find and cite this paper from ages ago)
1231
+ self.asymmetric_spo = asymmetric_spo
1232
+
1233
+ # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
1234
+
1235
+ self.recurrent_cache = recurrent_cache
750
1236
 
751
- self.recurrent_kv_cache = recurrent_kv_cache
1237
+ # parse environment returns into dictionaries
1238
+
1239
+ self.parse_env_reset_out = default(parse_env_reset_out, default_parse_env_reset_out)
1240
+ self.parse_env_step_out = default(parse_env_step_out, default_parse_env_step_out)
752
1241
 
753
1242
  # reward shaping function
754
1243
 
755
1244
  self.has_reward_shaping = exists(reward_shaping_fns)
1245
+
1246
+ if is_bearable(reward_shaping_fns, OneRewardShaper):
1247
+ reward_shaping_fns = [reward_shaping_fns]
1248
+
756
1249
  self.reward_shaping_fns = reward_shaping_fns
1250
+ self.reward_shaping_fns_multiple_envs = is_bearable(reward_shaping_fns, list[list[OneRewardShaper]])
757
1251
 
758
1252
  # loss related
759
1253
 
@@ -764,7 +1258,10 @@ class Locoformer(Module):
764
1258
  return next(self.parameters()).device
765
1259
 
766
1260
  def actor_parameters(self):
767
- return self.unembedder.parameters()
1261
+ return [
1262
+ *self.policy_network.parameters(),
1263
+ *self.unembedder.parameters()
1264
+ ]
768
1265
 
769
1266
  def critic_parameters(self):
770
1267
  if not exists(self.to_value_pred):
@@ -772,41 +1269,140 @@ class Locoformer(Module):
772
1269
 
773
1270
  return self.to_value_pred.parameters()
774
1271
 
1272
+ @beartype
1273
+ def learn(
1274
+ self,
1275
+ optims,
1276
+ accelerator,
1277
+ replay,
1278
+ state_embed_kwargs: dict,
1279
+ action_select_kwargs: dict,
1280
+ state_id_kwarg: dict = dict(),
1281
+ batch_size = 16,
1282
+ epochs = 2,
1283
+ use_vision = False,
1284
+ compute_state_pred_loss = False,
1285
+ state_pred_loss_weight = None,
1286
+ maybe_construct_trial_from_buffer: Callable[[ReplayBuffer], Tensor] | None = None
1287
+ ):
1288
+ state_field = 'state_image' if use_vision else 'state'
1289
+
1290
+ episode_mapping = None
1291
+
1292
+ if exists(maybe_construct_trial_from_buffer):
1293
+ episode_mapping = maybe_construct_trial_from_buffer(replay)
1294
+
1295
+ dataset = replay.dataset()
1296
+
1297
+ if exists(episode_mapping):
1298
+ dataset = RemappedReplayDataset(dataset, episode_mapping)
1299
+
1300
+ dl = replay.dataloader(
1301
+ batch_size = batch_size,
1302
+ dataset = dataset,
1303
+ shuffle = True
1304
+ )
1305
+
1306
+ self, dl, *optims = accelerator.prepare(self, dl, *optims)
1307
+
1308
+ for _ in range(epochs):
1309
+ for data in dl:
1310
+
1311
+ data = SimpleNamespace(**data)
1312
+
1313
+ actor_loss, critic_loss = self.ppo(
1314
+ state = getattr(data, state_field),
1315
+ internal_state = getattr(data, 'internal_state', None),
1316
+ action = data.action,
1317
+ action_log_prob = data.action_log_prob,
1318
+ reward = data.reward,
1319
+ value = data.value,
1320
+ done = data.done,
1321
+ condition = getattr(data, 'condition', None),
1322
+ cond_mask = getattr(data, 'cond_mask', None),
1323
+ episode_lens = data._lens,
1324
+ optims = optims,
1325
+ state_embed_kwargs = state_embed_kwargs,
1326
+ action_select_kwargs = action_select_kwargs,
1327
+ state_id_kwarg = state_id_kwarg,
1328
+ compute_state_pred_loss = compute_state_pred_loss,
1329
+ state_pred_loss_weight = state_pred_loss_weight,
1330
+ accelerator = accelerator
1331
+ )
1332
+
1333
+ accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
1334
+
1335
+ def evolve(
1336
+ self,
1337
+ environment,
1338
+ **kwargs
1339
+ ):
1340
+ evo_strat = EvoStrategy(self, environment = environment, **kwargs)
1341
+ evo_strat()
1342
+
775
1343
  def ppo(
776
1344
  self,
777
1345
  state,
1346
+ internal_state,
778
1347
  action,
779
- old_action_log_prob,
1348
+ action_log_prob,
780
1349
  reward,
781
- old_value,
782
- mask,
1350
+ value,
1351
+ done,
783
1352
  episode_lens,
784
- actor_optim: Optimizer | None = None,
785
- critic_optim: Optimizer | None = None
1353
+ condition: Tensor | None = None,
1354
+ cond_mask: Tensor | None = None,
1355
+ optims: list[Optimizer] | None = None,
1356
+ state_embed_kwargs: dict = dict(),
1357
+ action_select_kwargs: dict = dict(),
1358
+ state_id_kwarg: dict = dict(),
1359
+ compute_state_pred_loss = True,
1360
+ state_pred_loss_weight = None,
1361
+ accelerator = None,
1362
+ max_grad_norm = 0.5
786
1363
  ):
787
- window_size = self.window_size
788
- total_learnable_tokens = mask.sum().item()
1364
+ state_pred_loss_weight = default(state_pred_loss_weight, self.state_pred_loss_weight)
789
1365
 
1366
+ window_size = self.window_size
1367
+ mask = ~done
790
1368
  seq_len = state.shape[1]
791
- gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
1369
+ padding_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
1370
+ gae_mask = padding_mask & mask
792
1371
 
793
- advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
1372
+ total_learnable_tokens = gae_mask.sum().item()
794
1373
 
795
- advantage = normalize(advantage)
1374
+ advantage, returns = calc_gae(reward, value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
796
1375
 
797
- windowed_tensors = [
798
- t.split(window_size, dim = 1) for t in
799
- (
800
- state,
801
- action,
802
- old_action_log_prob,
803
- reward,
804
- old_value,
805
- mask,
806
- advantage,
807
- returns
808
- )
809
- ]
1376
+ advantage = normalize(advantage, mask = gae_mask)
1377
+
1378
+ advantage = rearrange(advantage, '... -> ... 1')
1379
+
1380
+ past_action = pad_at_dim(action, (1, -1), dim = -2)
1381
+
1382
+ data_dict = dict(
1383
+ state = state,
1384
+ internal_state = internal_state,
1385
+ action = action,
1386
+ past_action = past_action,
1387
+ old_action_log_prob = action_log_prob,
1388
+ reward = reward,
1389
+ mask = mask,
1390
+ advantage = advantage,
1391
+ returns = returns,
1392
+ windowed_gae_mask = gae_mask,
1393
+ condition = condition,
1394
+ cond_mask = cond_mask
1395
+ )
1396
+
1397
+ num_windows = math.ceil(seq_len / window_size)
1398
+
1399
+ windowed_data = dict()
1400
+
1401
+ for name, tensor in data_dict.items():
1402
+ if exists(tensor):
1403
+ windowed_data[name] = tensor.split(window_size, dim = 1)
1404
+ else:
1405
+ windowed_data[name] = (None,) * num_windows
810
1406
 
811
1407
  mean_actor_loss = self.zero.clone()
812
1408
  mean_critic_loss = self.zero.clone()
@@ -815,52 +1411,90 @@ class Locoformer(Module):
815
1411
 
816
1412
  cache = None
817
1413
 
818
- for (
819
- state,
820
- action,
821
- old_action_log_prob,
822
- reward,
823
- old_value,
824
- mask,
825
- advantage,
826
- returns
827
- ) in zip(*windowed_tensors):
1414
+ for window_tensors in zip(*windowed_data.values()):
828
1415
 
829
- (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
830
- entropy = calc_entropy(action_logits)
1416
+ data = SimpleNamespace(**dict(zip(windowed_data.keys(), window_tensors)))
831
1417
 
832
- action = rearrange(action, 'b t -> b t 1')
833
- log_prob = action_logits.gather(-1, action)
834
- log_prob = rearrange(log_prob, 'b t 1 -> b t')
1418
+ ((action_logits, maybe_state_pred), value_logits), cache = self.forward(data.state, past_action = data.past_action if self.embed_past_action else None, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': data.internal_state}, action_select_kwargs = action_select_kwargs, state_id_kwarg = state_id_kwarg, condition = data.condition, cond_mask = data.cond_mask, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True, return_state_pred = True)
1419
+
1420
+ log_prob = self.unembedder.log_prob(action_logits, data.action, **action_select_kwargs)
835
1421
 
836
1422
  # update actor, classic clipped surrogate loss
837
1423
 
838
1424
  eps_clip = self.ppo_eps_clip
839
- ratio = (log_prob - old_action_log_prob).exp()
1425
+ ratio = (log_prob - data.old_action_log_prob).exp()
1426
+
1427
+ calc_spo = lambda: -(ratio * data.advantage - (data.advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
1428
+
1429
+ calc_ppo = lambda: -torch.min(ratio * data.advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * data.advantage)
840
1430
 
841
- if self.use_spo:
842
- actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
1431
+ if self.asymmetric_spo:
1432
+ actor_loss = torch.where(data.advantage >= 0, calc_ppo(), calc_spo())
1433
+ elif self.use_spo:
1434
+ actor_loss = calc_spo()
843
1435
  else:
844
- actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
1436
+ actor_loss = calc_ppo()
845
1437
 
846
- actor_loss = actor_loss - self.ppo_entropy_weight * entropy
1438
+ # maybe entropy
847
1439
 
848
- windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
849
- windowed_actor_loss.backward(retain_graph = True)
1440
+ if self.ppo_entropy_weight > 0.:
1441
+ entropy = self.unembedder.entropy(action_logits, **action_select_kwargs)
850
1442
 
851
- # update critic
1443
+ if exists(entropy):
1444
+ actor_loss = actor_loss - self.ppo_entropy_weight * entropy
1445
+
1446
+ windowed_actor_loss = actor_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
1447
+
1448
+ # maybe add state prediction
1449
+
1450
+ if (
1451
+ exists(maybe_state_pred) and
1452
+ self.has_state_pred_loss and
1453
+ compute_state_pred_loss and
1454
+ data.windowed_gae_mask[:, :-1].any()
1455
+ ):
1456
+ state_pred = maybe_state_pred[:, :-1]
1457
+ state_labels = data.state[:, 1:]
1458
+ loss_mask = data.windowed_gae_mask[:, :-1]
1459
+
1460
+ state_id = state_id_kwarg.get('state_id', 0)
1461
+
1462
+ state_pred_loss = self.state_pred_head.calculate_loss(state_pred, state_labels, selector_index = state_id, return_unreduced_loss = True)
1463
+
1464
+ state_pred_loss = state_pred_loss.mean(dim = -1) # average over state features
1465
+
1466
+ windowed_state_pred_loss = state_pred_loss[loss_mask].sum() / total_learnable_tokens
1467
+
1468
+ windowed_actor_loss = (
1469
+ windowed_actor_loss +
1470
+ windowed_state_pred_loss * state_pred_loss_weight
1471
+ )
1472
+
1473
+ # maybe soft constrain continuous actions
1474
+
1475
+ if (
1476
+ self.ppo_soft_constrain_action_max and
1477
+ self.unembedder.has_continuous
1478
+ ):
1479
+ loss_mask = data.windowed_gae_mask
852
1480
 
853
- value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
1481
+ soft_constrain_loss = (action_logits[..., 0].abs() - self.ppo_soft_constrain_action_max).relu().pow(2)
1482
+ windowed_soft_constrain_loss = soft_constrain_loss[loss_mask].sum() / total_learnable_tokens
854
1483
 
855
- value_clip = self.ppo_value_clip
856
- value = self.hl_gauss_loss(value_logits)
1484
+ windowed_actor_loss = (
1485
+ windowed_actor_loss +
1486
+ windowed_soft_constrain_loss * self.ppo_soft_constrain_action_loss_weight
1487
+ )
857
1488
 
858
- clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
859
- clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
1489
+ # windowed loss
860
1490
 
861
- critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
1491
+ windowed_actor_loss.backward(retain_graph = True)
1492
+
1493
+ # update critic
862
1494
 
863
- windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
1495
+ value_loss = self.hl_gauss_loss(value_logits, data.returns, reduction = 'none') * self.value_loss_weight
1496
+
1497
+ windowed_critic_loss = value_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
864
1498
  windowed_critic_loss.backward(retain_graph = True)
865
1499
 
866
1500
  # accumulate
@@ -870,31 +1504,69 @@ class Locoformer(Module):
870
1504
 
871
1505
  # optimizer update
872
1506
 
873
- if exists(actor_optim):
874
- actor_optim.step()
875
- actor_optim.zero_grad()
1507
+ if exists(optims):
1508
+
1509
+ if exists(accelerator):
1510
+ accelerator.clip_grad_norm_(self.parameters(), max_grad_norm)
1511
+ else:
1512
+ nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
876
1513
 
877
- if exists(critic_optim):
878
- critic_optim.step()
879
- critic_optim.zero_grad()
1514
+ for optim in optims:
1515
+ optim.step()
1516
+ optim.zero_grad()
880
1517
 
881
1518
  # return losses for logging
882
1519
 
883
1520
  return mean_actor_loss.detach(), mean_critic_loss.detach()
884
1521
 
885
- def state_to_rewards(
1522
+ def state_and_command_to_rewards(
886
1523
  self,
887
- state
1524
+ state,
1525
+ commands = None,
1526
+ env_index: int | None = None
888
1527
  ) -> Tensor:
889
1528
 
890
1529
  assert self.has_reward_shaping
1530
+ assert xnor(exists(env_index), self.reward_shaping_fns_multiple_envs), '`env_index` must be passed in when per-environment lists of reward shaping functions are defined, and must be omitted when only a single list is given'
1531
+
1532
+ rewards = []
1533
+
1534
+ reward_shaping_fns = self.reward_shaping_fns[env_index] if exists(env_index) else self.reward_shaping_fns
891
1535
 
892
- rewards = [fn(state) for fn in self.reward_shaping_fns]
1536
+ for fn in reward_shaping_fns:
1537
+ param_names = get_param_names(fn)
1538
+ param_names = set(param_names) & {'state', 'command'}
1539
+
1540
+ if param_names == {'state'}: # only state
1541
+ reward = fn(state = state)
1542
+ elif param_names == {'state', 'command'}: # state and command
1543
+ reward = fn(state = state, command = commands)
1544
+ else:
1545
+ raise ValueError('reward shaping function must accept a `state` keyword argument, and optionally `command`')
1546
+
1547
+ rewards.append(reward)
1548
+
1549
+ # cast to Tensor if returns a float, just make it flexible for researcher
893
1550
 
894
1551
  rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
895
- return stack(rewards)
896
1552
 
897
- def wrap_env_functions(self, env):
1553
+ assert all([r.numel() == 1 for r in rewards])
1554
+
1555
+ if len(rewards) == 0:
1556
+ return None
1557
+
1558
+ packed_rewards, _ = pack(rewards, '*')
1559
+ return packed_rewards
1560
+
1561
+ @beartype
1562
+ def wrap_env_functions(
1563
+ self,
1564
+ env,
1565
+ env_output_transforms: dict[str, Callable] = dict(),
1566
+ state_transform: Callable = identity,
1567
+ reward_norm = 1.,
1568
+ command_generator: Callable = always(None)
1569
+ ):
898
1570
 
899
1571
  def transform_output(el):
900
1572
  if isinstance(el, ndarray):
@@ -907,23 +1579,57 @@ class Locoformer(Module):
907
1579
  def wrapped_reset(*args, **kwargs):
908
1580
  env_reset_out = env.reset(*args, **kwargs)
909
1581
 
910
- return tree_map(transform_output, env_reset_out)
1582
+ env_reset_out_torch = tree_map(transform_output, env_reset_out)
1583
+
1584
+ env_reset_out_dict = self.parse_env_reset_out(env_reset_out_torch)
1585
+
1586
+ env_reset_out_dict['state'] = state_transform(env_reset_out_dict['state'])
911
1587
 
912
- def wrapped_step(action, *args, **kwargs):
1588
+ derived_states = dict()
1589
+
1590
+ for derived_name, transform in env_output_transforms.items():
1591
+ derived_states[derived_name] = transform(env_reset_out_dict, env)
1592
+
1593
+ env_reset_out_dict['derived_state'] = derived_states
1594
+
1595
+ return env_reset_out_dict
1596
+
1597
+ def wrapped_step(action, *args, command = None, env_index = None, **kwargs):
913
1598
 
914
1599
  if is_tensor(action):
915
- action = action.item()
1600
+ if action.numel() == 1:
1601
+ action = action.item()
1602
+ else:
1603
+ action = action.tolist()
916
1604
 
917
1605
  env_step_out = env.step(action, *args, **kwargs)
918
1606
 
919
1607
  env_step_out_torch = tree_map(transform_output, env_step_out)
920
1608
 
921
- if not self.has_reward_shaping:
922
- return env_step_out_torch
1609
+ env_step_out_dict = self.parse_env_step_out(env_step_out_torch)
1610
+
1611
+ env_step_out_dict['state'] = state_transform(env_step_out_dict['state'])
1612
+
1613
+ env_step_out_dict['reward'] = env_step_out_dict['reward'] / reward_norm
1614
+
1615
+ if self.has_reward_shaping:
1616
+ shaped_rewards = self.state_and_command_to_rewards(env_step_out_dict['state'], command, env_index = env_index)
1617
+
1618
+ if exists(shaped_rewards):
1619
+ env_step_out_dict['shaped_rewards'] = shaped_rewards
1620
+
1621
+ # add shaped rewards to main reward
1622
+
1623
+ env_step_out_dict['reward'] = env_step_out_dict['reward'] + shaped_rewards.sum()
923
1624
 
924
- shaped_rewards = self.state_to_rewards(env_step_out_torch)
1625
+ derived_states = dict()
925
1626
 
926
- return env_step_out_torch, shaped_rewards
1627
+ for derived_name, transform in env_output_transforms.items():
1628
+ derived_states[derived_name] = transform(env_step_out_dict, env)
1629
+
1630
+ env_step_out_dict['derived_state'] = derived_states
1631
+
1632
+ return env_step_out_dict
927
1633
 
928
1634
  return wrapped_reset, wrapped_step
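
The wrapper above returns plain reset/step closures that hand the model torch tensors, a normalized reward, and a dict of derived states. A rough standalone sketch of that pattern against a toy gym-style environment (all names below are hypothetical, not the package's API):

import torch
import numpy as np

class ToyEnv:                                       # hypothetical stand-in for a gym-style env
    def reset(self):
        return np.zeros(3, dtype = np.float32), {}
    def step(self, action):
        next_state = np.random.randn(3).astype(np.float32)
        return next_state, 1.0, False, False, {}    # state, reward, terminated, truncated, info

def to_torch(el):
    return torch.from_numpy(el) if isinstance(el, np.ndarray) else el

def wrap_env(env, reward_norm = 10., derived = dict()):
    def wrapped_reset():
        state, info = env.reset()
        out = dict(state = to_torch(state), info = info)
        out['derived_state'] = {name: fn(out, env) for name, fn in derived.items()}
        return out

    def wrapped_step(action):
        if torch.is_tensor(action):
            action = action.item() if action.numel() == 1 else action.tolist()
        state, reward, terminated, truncated, info = env.step(action)
        out = dict(state = to_torch(state), reward = torch.tensor(reward / reward_norm),
                   terminated = torch.tensor(terminated), truncated = torch.tensor(truncated))
        out['derived_state'] = {name: fn(out, env) for name, fn in derived.items()}
        return out

    return wrapped_reset, wrapped_step

reset, step = wrap_env(ToyEnv(), derived = dict(speed = lambda out, env: out['state'].norm()))
print(reset()['derived_state'], step(torch.tensor([0.1, 0.2]))['reward'])
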
929
1635
 
@@ -936,24 +1642,54 @@ class Locoformer(Module):
936
1642
  state_time_dim = 1,
937
1643
  **kwargs
938
1644
  ):
939
- window_size = self.window_size
940
1645
 
941
1646
  cache = None
942
1647
 
943
- def stateful_forward(state: Tensor, **override_kwargs):
1648
+ def stateful_forward(
1649
+ state: Tensor,
1650
+ condition: Tensor | None = None,
1651
+ cond_mask: Tensor | None = None,
1652
+ **override_kwargs
1653
+ ):
944
1654
  nonlocal cache
945
1655
 
1656
+ state = state.to(self.device)
1657
+
1658
+ if exists(condition):
1659
+ condition = condition.to(self.device)
1660
+
1661
+ if exists(cond_mask):
1662
+ cond_mask = cond_mask.to(self.device)
1663
+
946
1664
  # handle no batch or time, for easier time rolling out against envs
947
1665
 
948
1666
  if not has_batch_dim:
949
1667
  state = rearrange(state, '... -> 1 ...')
950
1668
 
1669
+ if exists(condition):
1670
+ condition = rearrange(condition, '... -> 1 ...')
1671
+
1672
+ if exists(cond_mask):
1673
+ cond_mask = rearrange(cond_mask, '... -> 1 ...')
1674
+
951
1675
  if not has_time_dim:
952
1676
  state = state.unsqueeze(state_time_dim)
953
1677
 
1678
+ if exists(condition):
1679
+ condition = rearrange(condition, '... d -> ... 1 d')
1680
+
1681
+ if exists(cond_mask):
1682
+ cond_mask = cond_mask.unsqueeze(state_time_dim)
1683
+
954
1684
  # forwards
955
1685
 
956
- out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
1686
+ out, cache = self.forward(
1687
+ state,
1688
+ condition = condition,
1689
+ cond_mask = cond_mask,
1690
+ cache = cache,
1691
+ **{**kwargs, **override_kwargs}
1692
+ )
957
1693
 
958
1694
  # maybe remove batch or time
959
1695
 
@@ -984,49 +1720,275 @@ class Locoformer(Module):
984
1720
 
985
1721
  return stateful_forward, initial_logits
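
The closure returned here keeps the transformer cache as nonlocal state and adds/removes the batch and time dimensions, so a rollout loop can feed one unbatched observation at a time. A minimal sketch of the same idea with a toy cached model (hypothetical names, not the package's classes):

import torch
from torch import nn

class TinyCachedModel(nn.Module):                   # hypothetical model with an explicit cache
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
    def forward(self, x, cache = None):             # x: (batch, time, dim)
        hidden = self.proj(x)
        cache = hidden if cache is None else torch.cat((cache, hidden), dim = 1)
        return cache.mean(dim = 1), cache

def get_stateful_forward(model):
    cache = None
    def stateful_forward(state):                    # state: (dim,) - one unbatched timestep
        nonlocal cache
        state = state[None, None, :]                # add batch and time dims
        out, cache = model(state, cache = cache)    # cache is carried across calls
        return out[0]                               # strip the batch dim again
    return stateful_forward

model = TinyCachedModel()
step = get_stateful_forward(model)
for _ in range(3):
    print(step(torch.randn(4)).shape)               # cache grows internally, the caller never sees it
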
986
1722
 
1723
+ @beartype
1724
+ def gather_experience_from_env_(
1725
+ self,
1726
+ wrapped_env_functions: tuple[Callable, Callable],
1727
+ replay: ReplayBuffer,
1728
+ embed_past_action = False,
1729
+ max_timesteps = None,
1730
+ use_vision = False,
1731
+ action_select_kwargs: dict = dict(),
1732
+ state_embed_kwargs: dict = dict(),
1733
+ state_id_kwarg: dict = dict(),
1734
+ env_index: int | None = None,
1735
+ state_entropy_bonus_weight = 0.,
1736
+ action_rescale_range: tuple[float, float] | None = None,
1737
+ command_fn: Callable = always(None)
1738
+ ):
1739
+
1740
+ env_reset, env_step = wrapped_env_functions
1741
+
1742
+ reset_out_dict = env_reset()
1743
+ derived, state = pick(reset_out_dict, ('derived_state', 'state'))
1744
+
1745
+ state_image = derived.get('state_image', None)
1746
+ internal_state = derived.get('internal_state', None)
1747
+
1748
+ timestep = 0
1749
+
1750
+ max_timesteps = default(max_timesteps, replay.max_timesteps)
1751
+
1752
+ stateful_forward = self.get_stateful_forward(
1753
+ has_batch_dim = False,
1754
+ has_time_dim = False,
1755
+ inference_mode = True
1756
+ )
1757
+
1758
+ cum_rewards = 0.
1759
+
1760
+ with replay.one_episode() as final_meta_data_store_dict:
1761
+
1762
+ past_action = None
1763
+
1764
+ while True:
1765
+ state_for_model = state_image if use_vision else state
1766
+
1767
+ maybe_command = command_fn(state_for_model)
1768
+
1769
+ # predict next action
1770
+
1771
+ (action_logits, state_pred), value = stateful_forward(
1772
+ state_for_model,
1773
+ condition = maybe_command,
1774
+ cond_mask = tensor(exists(maybe_command)),
1775
+ state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state},
1776
+ action_select_kwargs = action_select_kwargs,
1777
+ state_id_kwarg = state_id_kwarg,
1778
+ past_action = past_action if embed_past_action else None,
1779
+ return_values = True,
1780
+ return_state_pred = True
1781
+ )
1782
+
1783
+ action = self.unembedder.sample(action_logits, **action_select_kwargs)
1784
+
1785
+ # maybe clip
1786
+
1787
+ if exists(action_rescale_range):
1788
+ min_val, max_val = action_rescale_range
1789
+ action = (action + 1.) * 0.5 * (max_val - min_val) + min_val
1790
+
1791
+ # pass to environment
1792
+
1793
+ step_dict = env_step(action, command = maybe_command, env_index = env_index)
1794
+
1795
+ derived, next_state, reward, terminated, truncated = pick(step_dict, ('derived_state', 'state', 'reward', 'terminated', 'truncated'))
1796
+
1797
+ next_state_image = derived.get('state_image', None)
1798
+ next_internal_state = derived.get('internal_state', None)
1799
+
1800
+ # maybe state entropy bonus
1801
+
1802
+ if state_entropy_bonus_weight > 0. and exists(state_pred):
1803
+ state_id = state_id_kwarg.get('state_id', 0)
1804
+ entropy = self.state_pred_head.entropy(state_pred, selector_index = state_id)
1805
+
1806
+ state_entropy_bonus = (entropy * state_entropy_bonus_weight).sum()
1807
+
1808
+ reward = reward + state_entropy_bonus.item() # the entropy is directly related to log variance
1809
+
1810
+ # cum rewards
1811
+
1812
+ cum_rewards += reward
1813
+
1814
+ # increment counters
1815
+ # we will store the step with done=False, as only the bootstrap/boundary node is done=True
1816
+
1817
+ exceeds_max_timesteps = max_timesteps >= 0 and timestep == (max_timesteps - 1)
1818
+ should_stop = truncated or terminated or tensor(exceeds_max_timesteps)
1819
+
1820
+ # get log prob of action
1821
+
1822
+ action_log_prob = self.unembedder.log_prob(action_logits, action, **action_select_kwargs)
1823
+
1824
+ memory = replay.store(
1825
+ state = state,
1826
+ state_image = state_image,
1827
+ action = action,
1828
+ action_log_prob = action_log_prob,
1829
+ internal_state = internal_state,
1830
+ reward = reward,
1831
+ value = value,
1832
+ done = tensor(False),
1833
+ condition = maybe_command,
1834
+ cond_mask = tensor(exists(maybe_command))
1835
+ )
1836
+
1837
+ timestep += 1
1838
+
1839
+ # break if done or exceed max timestep
1840
+ if should_stop:
1841
+
1842
+ # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
1843
+ # only if terminated signal not detected
1844
+ if not terminated:
1845
+ next_state_for_model = next_state_image if use_vision else next_state
1846
+
1847
+ _, next_value = stateful_forward(next_state_for_model, condition = maybe_command, cond_mask = tensor(exists(maybe_command)), return_values = True, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state}, state_id_kwarg = state_id_kwarg, action_select_kwargs = action_select_kwargs)
1848
+
1849
+ terminal_node = dict(
1850
+ state = next_state,
1851
+ state_image = next_state_image,
1852
+ internal_state = next_internal_state,
1853
+ value = next_value,
1854
+ reward = next_value,
1855
+ done = True,
1856
+ condition = maybe_command,
1857
+ cond_mask = exists(maybe_command)
1858
+ )
1859
+
1860
+ else:
1861
+ # terminal node - store a step with 0 reward and value, and done=True, to stop GAE scan
1862
+ terminal_node = dict(
1863
+ state = next_state,
1864
+ state_image = next_state_image,
1865
+ internal_state = next_internal_state,
1866
+ value = torch.zeros_like(value),
1867
+ reward = torch.zeros_like(reward),
1868
+ done = True,
1869
+ condition = maybe_command,
1870
+ cond_mask = exists(maybe_command)
1871
+ )
1872
+
1873
+ terminal_node = {key: value for key, value in terminal_node.items() if key in memory._fields}
1874
+
1875
+ terminal_memory = memory._replace(**terminal_node)
1876
+
1877
+ replay.store(**terminal_memory._asdict())
1878
+
1879
+ # store the final cumulative reward into meta data
1880
+
1881
+ final_meta_data_store_dict.update(cum_rewards = cum_rewards)
1882
+
1883
+ break
1884
+
1885
+ state = next_state
1886
+ state_image = next_state_image
1887
+ internal_state = next_internal_state
1888
+
1889
+ past_action = action
1890
+
1891
+ return cum_rewards
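
The rollout above closes every episode with a boundary node: the critic's value of the next state is stored when the episode was merely truncated (or hit the timestep limit), and zeros are stored when it truly terminated, so the GAE scan bootstraps correctly in both cases. A small illustrative sketch of how that bootstrap value enters a plain-Python GAE computation (standard GAE recursion, not the package's vectorized scan):

import torch

def gae(rewards, values, bootstrap_value, gamma = 0.99, lam = 0.95):
    # generalized advantage estimation over one episode; bootstrap_value is the critic's
    # estimate of the state after the last stored step (zero if the episode truly terminated)
    values = torch.cat((values, bootstrap_value[None]))
    advantages = torch.zeros_like(rewards)
    gae_accum = 0.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae_accum = delta + gamma * lam * gae_accum
        advantages[t] = gae_accum
    returns = advantages + values[:-1]
    return advantages, returns

rewards = torch.tensor([1., 1., 1.])
values  = torch.tensor([0.5, 0.6, 0.7])

# truncated episode: bootstrap with the critic's estimate of the next state
print(gae(rewards, values, bootstrap_value = torch.tensor(0.65)))

# terminated episode: no future value, so bootstrap with zero (mirrors the zeroed terminal node above)
print(gae(rewards, values, bootstrap_value = torch.tensor(0.)))
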
1892
+
987
1893
  def forward(
988
1894
  self,
989
1895
  state: Tensor,
990
- cache: Cache | None = None,
1896
+ cache: TransformerMemory | None = None,
1897
+ condition: Tensor | None = None,
1898
+ cond_mask: Tensor | None = None,
1899
+ past_action: Tensor | None = None,
1900
+ state_embed_kwargs: dict = dict(),
1901
+ action_select_kwargs: dict = dict(),
1902
+ state_id_kwarg: dict = dict(),
991
1903
  detach_cache = False,
992
1904
  return_values = False,
1905
+ return_state_pred = False,
993
1906
  return_raw_value_logits = False
994
1907
  ):
995
1908
 
996
1909
  state = state.to(self.device)
997
1910
 
998
- tokens = self.embedder(state)
1911
+ # move condition
999
1912
 
1000
- # time
1913
+ if exists(condition):
1914
+ condition = condition.to(self.device)
1001
1915
 
1002
- time = tokens.shape[-2]
1916
+ # determine which function to invoke for state to token for transformer
1003
1917
 
1004
- # destruct the cache for the current timestep and the cache
1918
+ state_to_token = self.embedder
1005
1919
 
1006
- prev_kv_cache = None
1007
- timestep_start = 0
1920
+ # embed
1008
1921
 
1009
- if exists(cache):
1010
- timestep_start, prev_kv_cache = cache
1922
+ tokens = state_to_token(state, **state_embed_kwargs, **state_id_kwarg)
1923
+
1924
+ # maybe add past action
1925
+
1926
+ # determine if first window and start of sequence
1927
+
1928
+ total_tokens = cache.total_tokens if exists(cache) else 0
1929
+
1930
+ is_start_of_sequence = total_tokens == 0
1931
+
1932
+ # maybe add past action
1933
+
1934
+ if exists(past_action):
1935
+ assert self.embed_past_action
1936
+ past_action_embed = self.past_action_embedder(past_action, **action_select_kwargs)
1937
+
1938
+ if is_start_of_sequence:
1939
+ past_action_embed = pad_at_dim(past_action_embed[..., 1:, :], (1, 0), dim = -2)
1940
+
1941
+ tokens = tokens + past_action_embed
1942
+
1943
+ # time
1944
+
1945
+ time = tokens.shape[-2]
1011
1946
 
1012
1947
  # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
1013
1948
 
1014
- assert ((timestep_start % self.window_size) + time) <= self.window_size
1949
+ assert ((total_tokens % self.window_size) + time) <= self.window_size
1015
1950
 
1016
1951
  # attention
1017
1952
 
1018
- embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)
1953
+ if not exists(cache):
1954
+ memory_segments = deque(maxlen = self.max_mem_segments)
1955
+ else:
1956
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
1957
+
1958
+ attn_cache = cache
1959
+
1960
+ embed, cache = self.transformer(
1961
+ tokens,
1962
+ condition = condition,
1963
+ cond_mask = cond_mask,
1964
+ cache = attn_cache,
1965
+ return_kv_cache = True
1966
+ )
1019
1967
 
1020
1968
  # unembed to actions - in language models this would be the next state
1021
1969
 
1022
- action_logits = self.unembedder(embed)
1970
+ policy_embed = self.policy_network(embed)
1971
+
1972
+ action_logits = self.unembedder(policy_embed, **action_select_kwargs)
1023
1973
 
1024
1974
  out = action_logits
1025
1975
 
1976
+ # maybe return state prediction
1977
+
1978
+ if return_state_pred:
1979
+ state_pred = None
1980
+
1981
+ if self.can_pred_state:
1982
+ state_id = state_id_kwarg.get('state_id', 0)
1983
+ state_pred_embed = self.state_pred_network(embed)
1984
+ state_pred = self.state_pred_head(state_pred_embed, selector_index = state_id)
1985
+
1986
+ out = (out, state_pred)
1987
+
1026
1988
  # maybe detach cache
1027
1989
 
1028
1990
  if detach_cache:
1029
- kv_cache = kv_cache.detach()
1991
+ cache = tree_map_tensor(cache, lambda t: t.detach())
1030
1992
 
1031
1993
  # handle returning of values
1032
1994
 
@@ -1040,20 +2002,22 @@ class Locoformer(Module):
1040
2002
 
1041
2003
  out = (out, values)
1042
2004
 
1043
- # output and cache
2005
+ # handle curtailing kv cache at the right intervals
1044
2006
 
1045
- next_timestep = time + timestep_start
2007
+ total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
1046
2008
 
1047
- # handle curtailing kv cache at the right intervals
2009
+ if self.fixed_window_size or divisible_by(total_tokens, self.window_size * 2):
2010
+ kv_cache = kv_cache[..., -self.window_size:, :]
1048
2011
 
1049
- window_size = self.window_size
2012
+ if self.recurrent_cache and divisible_by(total_tokens, self.window_size):
2013
+ kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
1050
2014
 
1051
- if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
1052
- kv_cache = kv_cache[..., -window_size:, :]
2015
+ if exists(gru_cache):
2016
+ gru_cache = torch.roll(gru_cache, shifts = -1, dims = 0)
1053
2017
 
1054
- # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
2018
+ if divisible_by(total_tokens, self.window_size):
2019
+ memory_segments.append(kv_cache.detach())
1055
2020
 
1056
- if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
1057
- kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
2021
+ cache = TransformerMemory(total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments)
1058
2022
 
1059
- return out, (next_timestep, kv_cache)
2023
+ return out, cache
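
For intuition on the cache bookkeeping at the end of forward above: the kv cache grows along the sequence dimension, is cut back to the most recent window (every step when the window size is fixed, otherwise every two windows), and in the recurrent variant is rolled one layer downward at each window boundary so lower layers inherit what the layer above cached, extending the effective receptive field into the past. A toy sketch with made-up shapes (the real cache travels inside the TransformerMemory returned by forward):

import torch

layers, window_size = 4, 6
# hypothetical per-layer kv cache laid out as (layers, heads, seq, dim_head), growing along seq
kv_cache = torch.randn(layers, 2, 11, 8)
total_tokens = 12                                   # tokens seen so far, including the one just processed

# curtail: once two windows have accumulated, keep only the most recent window per layer
if (total_tokens % (window_size * 2)) == 0:
    kv_cache = kv_cache[..., -window_size:, :]

# recurrent variant: at every window boundary, shift each layer's cache down one layer,
# so layer l attends to what layer l + 1 cached last window, stretching the receptive field
if (total_tokens % window_size) == 0:
    kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)

print(kv_cache.shape)
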