locoformer 0.0.43__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.
- locoformer/locoformer.py +1129 -452
- {locoformer-0.0.43.dist-info → locoformer-0.1.1.dist-info}/METADATA +30 -3
- locoformer-0.1.1.dist-info/RECORD +6 -0
- locoformer-0.0.43.dist-info/RECORD +0 -6
- {locoformer-0.0.43.dist-info → locoformer-0.1.1.dist-info}/WHEEL +0 -0
- {locoformer-0.0.43.dist-info → locoformer-0.1.1.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,11 +1,14 @@
 from __future__ import annotations
+import math
 from typing import Callable
 from types import SimpleNamespace
 from functools import partial, wraps

 from pathlib import Path
 from contextlib import contextmanager
-from collections import namedtuple
+from collections import namedtuple, deque
+
+from glom import glom

 from inspect import signature

@@ -17,16 +20,17 @@ from beartype import beartype
 from beartype.door import is_bearable

 import torch
-from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
+from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy, nested
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
+from torch.distributions import Normal
 from torch.optim import Optimizer

 import einx
-from einops import rearrange, einsum
-from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, einsum, reduce, pack
+from einops.layers.torch import Rearrange, Reduce

 from rotary_embedding_torch import RotaryEmbedding

@@ -38,11 +42,24 @@ from x_mlps_pytorch import MLP

 from x_evolution import EvoStrategy

+from discrete_continuous_embed_readout import EmbedAndReadout, Embed, Readout
+
+from hyper_connections import mc_get_init_and_expand_reduce_stream_functions
+
+from memmap_replay_buffer import ReplayBuffer, ReplayDataset
+
 # constants

 LinearNoBias = partial(Linear, bias = False)

-
+TransformerMemory = namedtuple('TransformerMemory', (
+    'total_tokens',
+    'kv_cache',
+    'gru_cache',
+    'mem_mlp_cache',
+    'mem_mlp_hidden_states',
+    'memory_segments'
+))

 # helper functions

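Note: `TransformerMemory` bundles all of the recurrent state that the rewritten `TransformerXL.forward` threads from one window to the next (running token count, per-layer kv cache, GRU hiddens, memory-MLP fast weights and pending hidden states, and the deque of past segments). A minimal sketch of carrying it across calls - the constructor arguments here are illustrative, not taken from the package:

    import torch

    # illustrative arguments, not the package defaults
    model = TransformerXL(dim = 256, depth = 4, dim_head = 64, heads = 8, window_size = 32)

    tokens = torch.randn(1, 96, 256)

    cache = None
    for window in tokens.split(32, dim = 1):
        # each call returns the embedding plus the updated TransformerMemory
        embed, cache = model(window, cache = cache)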
@@ -52,6 +69,18 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def always(val):
+    def inner(*args, **kwargs):
+        return val
+
+    return inner
+
+def identity(t, *args, **kwargs):
+    return t
+
+def pick(data, keys):
+    return tuple(data[k] for k in keys)
+
 def first(arr):
     return arr[0]

@@ -100,6 +129,11 @@ def is_empty(t):
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

+def lens_to_mask(lens, max_len):
+    device = lens.device
+    seq = arange(max_len, device = device)
+    return einx.less('j, i -> i j', seq, lens)
+
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -113,8 +147,20 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

-def
-
+def safe_cat(t, next_t, dim = -1):
+    if not exists(t):
+        return next_t
+
+    return cat((t, next_t), dim = dim)
+
+def normalize(t, mask = None, eps = 1e-5):
+    if exists(mask):
+        assert mask.any()
+
+    t_for_stats = t[mask] if exists(mask) else t
+    var, mean = torch.var_mean(t_for_stats)
+
+    return (t - mean) / var.sqrt().clamp_min(eps)

 def tensor_to_dict(
     t: Tensor,
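Note: the new `normalize` helper standardizes a tensor using mean and variance computed only over masked-in elements; the PPO path below uses it to normalize GAE advantages while ignoring padding and done timesteps. A small self-contained sketch of the behavior:

    import torch

    def normalize(t, mask = None, eps = 1e-5):  # as defined in the diff above
        t_for_stats = t[mask] if mask is not None else t
        var, mean = torch.var_mean(t_for_stats)
        return (t - mean) / var.sqrt().clamp_min(eps)

    t = torch.tensor([1., 2., 3., 100.])
    mask = torch.tensor([True, True, True, False])

    print(normalize(t, mask = mask))  # stats come from [1, 2, 3] only; the outlier is shifted and scaled, not clipped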
@@ -135,9 +181,75 @@ def tensor_to_dict(

     return SimpleNamespace(**tensor_dict)

-
-
-
+# dataset related
+
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False,
+        num_trials_select = None
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+        assert not (exists(num_trials_select) and num_trials_select <= 0)
+        self.sub_select_trials = exists(num_trials_select)
+        self.num_trials_select = num_trials_select
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        # shuffle the episode indices if either shuffle episodes is turned on, or `num_trials_select` passed in (for sub selecting episodes from a set)
+
+        if (
+            episode_indices.numel() > 1 and
+            (self.shuffle_episodes or self.sub_select_trials)
+        ):
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        # crop out the episodes
+
+        if self.sub_select_trials:
+            episode_indices = episode_indices[:self.num_trials_select]
+
+        # now select out the episode data and merge along time
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data

 # reward functions - A.2

@@ -300,305 +412,6 @@ def create_sliding_mask(
     create_kwargs = dict(device = device) if exists(device) else dict()
     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

-# data
-
-def collate_var_time(data):
-
-    datum = first(data)
-    keys = datum.keys()
-
-    all_tensors = zip(*[datum.values() for datum in data])
-
-    collated_values = []
-
-    for key, tensors in zip(keys, all_tensors):
-
-        # the episode lens have zero dimension - think of a cleaner way to handle this later
-
-        if key != '_lens':
-
-            times = [t.shape[0] for t in tensors]
-            max_time = max(times)
-            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
-
-        collated_values.append(stack(tensors))
-
-    return dict(zip(keys, collated_values))
-
-class ReplayDataset(Dataset):
-    def __init__(
-        self,
-        folder: str | Path,
-        fields: tuple[str, ...] | None = None
-    ):
-        if isinstance(folder, str):
-            folder = Path(folder)
-
-        episode_lens = folder / 'episode_lens.npy'
-        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
-
-        # get indices of non-zero lengthed episodes
-
-        nonzero_episodes = self.episode_lens > 0
-        self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
-
-        # get all data files
-
-        filepaths = [*folder.glob('*.data.npy')]
-        assert len(filepaths) > 0
-
-        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
-
-        fieldnames_from_files = set(fieldname_to_filepath.keys())
-
-        fields = default(fields, fieldnames_from_files)
-
-        self.memmaps = dict()
-
-        for field in fields:
-            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
-
-            path = fieldname_to_filepath[field]
-
-            self.memmaps[field] = open_memmap(str(path), mode = 'r')
-
-    def __len__(self):
-        return len(self.indices)
-
-    def __getitem__(self, idx):
-        episode_index = self.indices[idx]
-
-        episode_len = self.episode_lens[episode_index]
-
-        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
-
-        data['_lens'] = tensor(episode_len)
-
-        return data
-
-class RemappedReplayDataset(Dataset):
-    def __init__(
-        self,
-        dataset: ReplayDataset,
-        episode_mapping: Tensor | list[list[int]],
-        shuffle_episodes = False,
-        num_trials_select = None
-    ):
-        assert len(dataset) > 0
-        self.dataset = dataset
-
-        if is_tensor(episode_mapping):
-            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
-            episode_mapping = episode_mapping.tolist()
-
-        self.episode_mapping = episode_mapping
-        self.shuffle_episodes = shuffle_episodes
-
-        assert not (exists(num_trials_select) and num_trials_select >= 1)
-        self.sub_select_trials = exists(num_trials_select)
-        self.num_trials_select = num_trials_select
-
-    def __len__(self):
-        return len(self.episode_mapping)
-
-    def __getitem__(self, idx):
-
-        episode_indices = self.episode_mapping[idx]
-
-        episode_indices = tensor(episode_indices)
-        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
-
-        assert not is_empty(episode_indices)
-
-        # shuffle the episode indices if either shuffle episodes is turned on, or `num_trial_select` passed in (for sub selecting episodes from a set)
-
-        if (
-            episode_indices.numel() > 1 and
-            (self.shuffle_episodes or self.sub_select_trials)
-        ):
-            num_episodes = len(episode_indices)
-            episode_indices = episode_indices[torch.randperm(num_episodes)]
-
-        # crop out the episodes
-
-        if self.sub_select_trials:
-            episode_indices = episode_indices[:self.num_trials_select]
-
-        # now select out the episode data and merge along time
-
-        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
-
-        episode_lens = stack([data.pop('_lens') for data in episode_data])
-
-        keys = first(episode_data).keys()
-
-        values = [list(data.values()) for data in episode_data]
-
-        values = [cat(field_values) for field_values in zip(*values)] # concat across time
-
-        multi_episode_data = dict(zip(keys, values))
-
-        multi_episode_data['_lens'] = episode_lens.sum()
-
-        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
-
-        return multi_episode_data
-
-class ReplayBuffer:
-
-    @beartype
-    def __init__(
-        self,
-        folder: str | Path,
-        max_episodes: int,
-        max_timesteps: int,
-        fields: dict[
-            str,
-            str | tuple[str, int | tuple[int, ...]]
-        ]
-    ):
-
-        # folder for data
-
-        if not isinstance(folder, Path):
-            folder = Path(folder)
-        folder.mkdir(exist_ok = True)
-
-        self.folder = folder
-        assert folder.is_dir()
-
-        # keeping track of episode length
-
-        episode_lens = folder / 'episode_lens.npy'
-
-        self.episode_index = 0
-        self.timestep_index = 0
-
-        self.max_episodes = max_episodes
-        self.max_timesteps = max_timesteps
-
-        self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
-
-        # create the memmap for individual data tracks
-
-        self.shapes = dict()
-        self.dtypes = dict()
-        self.memmaps = dict()
-        self.fieldnames = set(fields.keys())
-
-        for field_name, field_info in fields.items():
-
-            # some flexibility
-
-            field_info = (field_info, ()) if isinstance(field_info, str) else field_info
-
-            dtype_str, shape = field_info
-            assert dtype_str in {'int', 'float', 'bool'}
-
-            dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
-
-            # memmap file
-
-            filepath = folder / f'{field_name}.data.npy'
-
-            if isinstance(shape, int):
-                shape = (shape,)
-
-            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
-
-            self.memmaps[field_name] = memmap
-            self.shapes[field_name] = shape
-            self.dtypes[field_name] = dtype
-
-        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
-
-    def __len__(self):
-        return (self.episode_lens > 0).sum().item()
-
-    def reset_(self):
-        self.episode_lens[:] = 0
-        self.episode_index = 0
-        self.timestep_index = 0
-
-    def advance_episode(self):
-        self.episode_index = (self.episode_index + 1) % self.max_episodes
-        self.timestep_index = 0
-
-    def flush(self):
-        self.episode_lens[self.episode_index] = self.timestep_index
-
-        for memmap in self.memmaps.values():
-            memmap.flush()
-
-        self.episode_lens.flush()
-
-    @contextmanager
-    def one_episode(self):
-
-        yield
-
-        self.flush()
-        self.advance_episode()
-
-    @beartype
-    def store_datapoint(
-        self,
-        episode_index: int,
-        timestep_index: int,
-        name: str,
-        datapoint: Tensor | ndarray
-    ):
-        assert 0 <= episode_index < self.max_episodes
-        assert 0 <= timestep_index < self.max_timesteps
-
-        if is_tensor(datapoint):
-            datapoint = datapoint.detach().cpu().numpy()
-
-        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
-
-        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
-
-        self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
-
-    def store(
-        self,
-        **data
-    ):
-        assert is_bearable(data, dict[str, Tensor | ndarray])
-
-        assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
-
-        for name, datapoint in data.items():
-
-            self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
-
-        self.timestep_index += 1
-
-        return self.memory_namedtuple(**data)
-
-    def dataset(
-        self,
-        episode_mapping: Tensor | list[list[int]] | None = None,
-    ) -> Dataset:
-        self.flush()
-
-        dataset = ReplayDataset(self.folder)
-
-        if not exists(episode_mapping):
-            return dataset
-
-        return RemappedReplayDataset(dataset, episode_mapping)
-
-    def dataloader(
-        self,
-        batch_size,
-        episode_mapping: Tensor | list[list[int]] | None = None,
-        **kwargs
-    ) -> DataLoader:
-        self.flush()
-
-        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
-
 # normalization + conditioning (needed for the commands to the robot)

 class MaybeAdaRMSNormWrapper(Module):
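Note: the block removed above (`collate_var_time`, `ReplayDataset`, `RemappedReplayDataset`, `ReplayBuffer`) moved out to the `memmap_replay_buffer` package that is now imported at the top of the file; only `RemappedReplayDataset` stays in-tree, re-added earlier with its `num_trials_select` assertion fixed (the old `>= 1` check rejected every valid value; the new code asserts `<= 0` instead). A usage sketch, assuming the external package preserves the interface that was defined here:

    import numpy as np
    from memmap_replay_buffer import ReplayBuffer

    buffer = ReplayBuffer(
        folder = './replay',
        max_episodes = 64,
        max_timesteps = 512,
        fields = dict(
            state = ('float', 11),   # dtype string plus trailing shape
            action = ('float', 3),
            reward = 'float'         # scalar fields may pass the dtype alone
        )
    )

    with buffer.one_episode():
        buffer.store(
            state = np.zeros(11, dtype = np.float32),
            action = np.zeros(3, dtype = np.float32),
            reward = np.zeros((), dtype = np.float32)
        )

    dl = buffer.dataloader(batch_size = 16)  # pads variable-length episodes along time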
@@ -627,34 +440,52 @@ class MaybeAdaRMSNormWrapper(Module):
     def forward(
         self,
         x,
+        *args,
         cond = None,
+        cond_mask = None,
         **kwargs
     ):

         need_cond = self.accept_condition
+        has_input_cond = need_cond and exists(cond)

-
+        if exists(cond):
+            assert self.accept_condition

         prenormed = self.norm(x)

-        if
+        if has_input_cond:
             if cond.ndim == 2:
                 cond = rearrange(cond, 'b d -> b 1 d')

-
-
+            cond_scale = self.to_gamma(cond)
+
+            conditioned = prenormed * cond_scale
+
+            # handle a condition mask

-
+            if exists(cond_mask):
+                prenormed = einx.where('b n, b n d, b n d', cond_mask, conditioned, prenormed)
+            else:
+                prenormed = conditioned
+
+        # the main block, either attention or feedforward or whatever

-
+        all_fn_out = self.fn(prenormed, *args, **kwargs)
+
+        if not has_input_cond:
             return all_fn_out

         # function may return multiple args

         (out, *rest), tree_spec = tree_flatten(all_fn_out)

-
-
+        scale_out = self.to_ada_norm_zero(cond).sigmoid()
+
+        if exists(cond_mask):
+            is_cond = rearrange(cond_mask, '... -> ... 1')
+            out = torch.where(is_cond, out * scale_out, out)
+        else:
             out = out * scale_out

         # restore
@@ -673,7 +504,8 @@ class Attention(Module):
         dim_head = 64,
         heads = 8,
         fixed_window_size = False,
-        accept_value_residual = False
+        accept_value_residual = False,
+        max_mem_segments = 1
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -709,12 +541,16 @@ class Attention(Module):

         self.fixed_window_size = fixed_window_size
         self.window_size = window_size
+        self.max_mem_segments = max_mem_segments
+
+        self.register_buffer('causal_mask', None, persistent = False)

     def forward(
         self,
         tokens,
         value_residual = None,
         kv_cache = None,
+        past_segments = None,
         return_kv_cache = False,
     ):
         seq_len = tokens.shape[-2]
@@ -739,6 +575,11 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)

+        if exists(past_segments):
+            pk, pv = past_segments
+            k = cat((pk, k), dim = -2)
+            v = cat((pv, v), dim = -2)
+
         if return_kv_cache:
             next_kv_cache = stack((k, v))

@@ -752,9 +593,12 @@ class Attention(Module):
             i_seq = arange(i, device = device)
             j_seq = arange(j, device = device) - (j - i)
             dist = einx.subtract('i, j -> i j', i_seq, j_seq)
-            causal_mask = (dist < 0) | (dist > self.window_size)
+            causal_mask = (dist < 0) | (dist > (self.max_mem_segments * self.window_size))
         else:
-
+            if not exists(self.causal_mask) or self.causal_mask.shape != (i, j):
+                self.causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+
+            causal_mask = self.causal_mask

         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

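Note: with up to `max_mem_segments` past segments of keys/values prepended, the fixed-window branch above widens the allowed lookback from `window_size` to `max_mem_segments * window_size` tokens. The mask computation in isolation:

    import torch
    import einx

    i, j = 4, 12                        # query length, and key length including prepended memory
    window_size, max_mem_segments = 4, 2

    i_seq = torch.arange(i)
    j_seq = torch.arange(j) - (j - i)
    dist = einx.subtract('i, j -> i j', i_seq, j_seq)

    # True entries are masked out: future positions, or positions beyond the widened window
    causal_mask = (dist < 0) | (dist > (max_mem_segments * window_size))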
@@ -801,6 +645,7 @@ class FeedForward(Module):
         return self.proj_out(x)

 class TransformerXL(Module):
+    @beartype
     def __init__(
         self,
         dim,
@@ -812,8 +657,31 @@ class TransformerXL(Module):
         dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
+        gru_layers = False,
+        long_term_mem_layers: tuple[int, ...] = (),
+        mem_kwargs: dict = dict(),
+        num_residual_streams = 1,
+        max_mem_segments = 1
     ):
         super().__init__()
+        self.dim = dim
+
+        # memory
+
+        long_term_mem_layers = set(long_term_mem_layers)
+
+        assert all([1 <= l <= depth for l in long_term_mem_layers])
+
+        self.long_term_mem_layers = long_term_mem_layers
+        self.num_mem_mlps = len(long_term_mem_layers)
+        self.has_mem = self.num_mem_mlps > 0
+        self.max_mem_segments = max_mem_segments
+
+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = mc_get_init_and_expand_reduce_stream_functions(num_residual_streams)
+
+        # condition

         condition = exists(dim_cond)

@@ -821,22 +689,48 @@ class TransformerXL(Module):

         norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)

+        # layers
+
         layers = ModuleList([])

         for i in range(depth):
-
+            layer = i + 1
+            is_first = layer == 1
+            has_mem = layer in long_term_mem_layers
+
+            gru = norm_fn(nn.GRU(dim, dim, batch_first = True)) if gru_layers else None

-
+            mem = MemoryMLP(dim, **mem_kwargs) if has_mem else None
+
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first, max_mem_segments = max_mem_segments))

             ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

+            # maybe hyper connections
+
+            mem_store_hc = init_hyper_conn(dim = dim, add_branch_out_to_residual = False)
+
+            if num_residual_streams > 1:
+
+                attn = init_hyper_conn(dim = dim, branch = attn)
+
+                ff = init_hyper_conn(dim = dim, branch = ff)
+
+                if gru_layers:
+                    gru = init_hyper_conn(dim = dim, branch = gru)
+
+                if has_mem:
+                    mem = init_hyper_conn(dim = dim, branch = mem, forward_method_names = ('store',))
+
             layers.append(ModuleList([
-                attn, ff
+                gru, mem, mem_store_hc, attn, ff
             ]))

         self.layers = layers
         self.norm = RMSNorm(dim) if final_norm else Identity()

+        self.gru_layers = gru_layers
+
         # fixed window size

         self.fixed_window_size = fixed_window_size
@@ -845,92 +739,439 @@ class TransformerXL(Module):
     def forward(
         self,
         x,
-        cache = None,
+        cache: TransformerMemory | None = None,
         return_kv_cache = False,
-        condition: Tensor | None = None
+        condition: Tensor | None = None,
+        cond_mask: Tensor | None = None
     ):
+        curr_token_seq_len = x.shape[-2]

         # cache and residuals

-
+        num_layers = len(self.layers)
+
+        # extract variables from cache
+
+        is_first_window = True
+        total_tokens = 0
+        kv_cache = gru_cache = mem_mlp_cache = mem_mlp_hidden_states = memory_segments = None
+
+        if exists(cache):
+            total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
+            is_first_window = total_tokens < self.window_size
+        else:
+            memory_segments = deque(maxlen = self.max_mem_segments)
+
+        # handle memory segments
+
+        past_segments = None
+        if len(memory_segments) > 0:
+            past_segments = stack(list(memory_segments), dim = 0)
+            past_segments = rearrange(past_segments, 'l depth kv b h n d -> depth kv b h (l n) d')
+
+        kv_cache = default(kv_cache, (None,) * num_layers)
+        gru_cache = default(gru_cache, (None,) * num_layers)
+        mem_mlp_cache = default(mem_mlp_cache, (None,) * num_layers)
+        mem_mlp_hidden_states = default(mem_mlp_hidden_states, (None,) * num_layers)
+
+        # prepare next cache

         next_kv_caches = []
+        next_gru_hiddens = [] if self.gru_layers else None
+        next_mem_mlp_cache = [] if self.has_mem else None
+        next_mem_mlp_hidden_states = [] if self.has_mem else None
+        next_total_tokens = total_tokens + curr_token_seq_len
+
+        is_window_boundary = divisible_by(next_total_tokens, self.window_size)
+
         value_residual = None

         # handle condition

         cond_tokens = None
+
         if exists(condition):
             assert exists(self.to_cond_tokens)
             cond_tokens = self.to_cond_tokens(condition)

+        cond_kwargs = dict(cond = cond_tokens, cond_mask = cond_mask)
+
+        # hc expand
+
+        x = self.expand_streams(x)
+
         # layers

-        for (attn, ff),
+        for layer_index, ((maybe_gru, maybe_mem, maybe_mem_store_hc, attn, ff), layer_gru_cache, layer_mem_mlp, layer_kv_cache, layer_hidden_states) in enumerate(zip(self.layers, gru_cache, mem_mlp_cache, kv_cache, mem_mlp_hidden_states)):

-
+            # handle maybe rnn

-
-
+            if exists(maybe_gru):
+                x, gru_hiddens = maybe_gru(x, layer_gru_cache, **cond_kwargs)
+
+                next_gru_hiddens.append(gru_hiddens)
+
+            # maybe handle retrieving
+
+            is_mem_layer = exists(maybe_mem)
+
+            if (
+                not is_first_window and
+                is_mem_layer
+            ):
+                x = maybe_mem(x, layer_mem_mlp)
+
+            # attention
+
+            layer_past_segments = None
+            if exists(past_segments):
+                layer_past_segments = past_segments[layer_index]
+
+            x, (next_kv_cache, values) = attn(x, **cond_kwargs, value_residual = value_residual, kv_cache = layer_kv_cache, past_segments = layer_past_segments, return_kv_cache = True)
+
+            # handle storing of memory
+
+            if self.has_mem:
+                next_mem_mlp = layer_mem_mlp
+                next_layer_hidden_states = layer_hidden_states
+
+                if is_mem_layer:
+                    # accumulate hidden states
+                    next_layer_hidden_states = safe_cat(layer_hidden_states, x, dim = -2)
+
+                    if is_window_boundary:
+                        mem_store_input, _ = maybe_mem_store_hc(next_layer_hidden_states)
+
+                        next_mem_mlp = maybe_mem.store(mem_store_input, layer_mem_mlp)
+                        next_layer_hidden_states = None
+
+                next_mem_mlp_cache.append(next_mem_mlp)
+                next_mem_mlp_hidden_states.append(next_layer_hidden_states)
+
+            # feedforward
+
+            x = ff(x, **cond_kwargs)

             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)

-
+        # hc reduce

-
-
+        x = self.reduce_streams(x)
+
+        # norm
+
+        embed = self.norm(x)

         next_kv_cache = stack(next_kv_caches)

-
+        if exists(next_gru_hiddens):
+            next_gru_hiddens = stack(next_gru_hiddens)
+
+        next_cache = TransformerMemory(next_total_tokens, next_kv_cache, next_gru_hiddens, next_mem_mlp_cache, next_mem_mlp_hidden_states, memory_segments)
+
+        return embed, next_cache
+
+# simple 2 layer memory mlp
+# following ttt/titans
+
+from torch.func import functional_call, grad, vmap
+
+class MemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        expansion_factor = 4.
+    ):
+        super().__init__()
+
+        dim_hidden = int(dim * expansion_factor)
+
+        self.norm = nn.RMSNorm(dim)
+
+        # queries, keys, values
+
+        self.to_queries = Linear(dim, dim, bias = False)
+
+        self.to_key_values = nn.Sequential(
+            Linear(dim, dim * 2, bias = False),
+            nn.SiLU()
+        )
+
+        # memory mlp
+
+        self.mlp = MLP(dim, dim_hidden, dim, activation = nn.SiLU())
+
+        # initial params
+
+        self.init_mlp_params = dict(self.mlp.named_parameters())
+
+        # grad for storing
+
+        def retrieve_fn(params, queries: Tensor):
+            return functional_call(self.mlp, params, queries)
+
+        def loss_fn(params, inputs: tuple[Tensor, Tensor, Tensor]):
+            keys, values, learning_rate = inputs
+            pred = functional_call(self.mlp, params, keys)
+            loss = F.mse_loss(pred, values, reduction = 'none')
+            loss = loss * learning_rate
+            return loss.mean()
+
+        self.grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0, 0)))

-
+        self.retrieve_fn = vmap(retrieve_fn, in_dims = (0, 0))
+
+        # forgetting
+
+        self.to_forget_gate = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, 1, bias = False),
+            Rearrange('b 1 -> b'),
+            nn.Sigmoid()
+        )
+
+        # loss weight / learning rate
+
+        self.to_loss_weight = nn.Linear(dim, 1, bias = False)
+
+    def get_init_mlp_params(
+        self,
+        batch_size
+    ):
+        return {name: repeat(params, '... -> b ...', b = batch_size) for name, params in self.init_mlp_params.items()}
+
+    def store(
+        self,
+        tokens, # (b n d)
+        memories: dict[str, Tensor] | None = None
+    ):
+
+        batch_size = tokens.shape[0]
+
+        if not exists(memories):
+            memories = self.get_init_mlp_params(batch_size)
+
+        tokens = self.norm(tokens)
+
+        keys, values = self.to_key_values(tokens).chunk(2, dim = -1)
+
+        loss_weight = self.to_loss_weight(tokens)
+
+        grad = self.grad_fn(memories, (keys, values, loss_weight))
+
+        # prepare forget
+
+        forget = self.to_forget_gate(tokens)
+
+        # update memories
+
+        next_memories = dict()
+
+        for param_name, past_memory in memories.items():
+            change = grad[param_name]
+
+            past_memory = einx.multiply('b, b ...', forget, past_memory)
+
+            next_memories[param_name] = past_memory - change
+
+        return next_memories
+
+    def forward(
+        self,
+        tokens, # (b n d)
+        memories: dict[str, Tensor] | None = None
+    ):
+        batch_size = tokens.shape[0]
+
+        if not exists(memories):
+            memories = self.get_init_mlp_params(batch_size)
+
+        tokens = self.norm(tokens)
+
+        queries = self.to_queries(tokens)
+
+        retrieved = self.retrieve_fn(memories, queries)
+
+        return retrieved
+
+# state embedder
+
+class StateEmbedder(Module):
+    @beartype
+    def __init__(
+        self,
+        dim,
+        dim_state: tuple[int, ...] | list[int] | int,
+        num_internal_states: int | None = None,
+        internal_states_selectors: list[list[int]] | None = None
+    ):
+        super().__init__()
+        dim_hidden = dim * 2
+
+        self.image_to_token = nn.Sequential(
+            Rearrange('b t c h w -> b c t h w'),
+            nn.Conv3d(3, dim_hidden, (1, 7, 7), padding = (0, 3, 3)),
+            nn.ReLU(),
+            nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
+            nn.ReLU(),
+            nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
+            Reduce('b c t h w -> b t c', 'mean'),
+            nn.Linear(dim_hidden, dim)
+        )
+
+        dim_states = (dim_state,) if not isinstance(dim_state, (tuple, list)) else dim_state
+
+        self.dim_states = dim_states
+        self.state_to_token = ModuleList([MLP(dim_state, dim, bias = False) for dim_state in dim_states])
+
+        # internal state embeds for each robot
+
+        self.internal_state_embedder = None
+
+        if exists(num_internal_states) and exists(internal_states_selectors):
+            self.internal_state_embedder = Embed(
+                dim,
+                num_continuous = num_internal_states,
+                selectors = internal_states_selectors
+            )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        state,
+        state_type,
+        state_id = 0,
+        internal_state = None,
+        internal_state_selector_id: int | None = None
+    ):
+
+        if state_type == 'image':
+            token_embeds = self.image_to_token(state)
+        elif state_type == 'raw':
+            state_to_token = self.state_to_token[state_id]
+            token_embeds = state_to_token(state)
+        else:
+            raise ValueError('invalid state type')
+
+        if (
+            exists(internal_state_selector_id) and
+            exists(internal_state) and
+            exists(self.internal_state_embedder)
+        ):
+            internal_state = internal_state.to(self.device)
+
+            internal_state_embed = self.internal_state_embedder(internal_state, selector_index = internal_state_selector_id)
+
+            token_embeds = token_embeds + internal_state_embed
+
+        return token_embeds

 # class

+OneRewardShaper = Callable[..., float | Tensor]
+
+MaybeOneRewardShaper = OneRewardShaper | None
+
+@beartype
+def default_parse_env_reset_out(reset_out: tuple):
+    assert len(reset_out) == 2
+    return dict(zip(('state', 'info'), reset_out))
+
+@beartype
+def default_parse_env_step_out(step_out: tuple):
+    assert len(step_out) in {4, 5}
+
+    if len(step_out) == 5:
+        data_dict = dict(zip(('state', 'reward', 'terminated', 'truncated', 'info'), step_out))
+    elif len(step_out) == 4:
+        data_dict = dict(zip(('state', 'reward', 'terminated', 'info'), step_out))
+        data_dict['truncated'] = False
+
+    return data_dict
+
 class Locoformer(Module):
     def __init__(
         self,
-        embedder:
-        unembedder:
+        embedder: dict | Module,
+        unembedder: dict | Readout,
         transformer: dict | TransformerXL,
+        *,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
+        ppo_soft_constrain_action_max = None,
+        ppo_soft_constrain_action_loss_weight = 0.1,
         dim_value_input = None, # needs to be set for value network to be available
         value_network: Module = nn.Identity(),
+        policy_network: Module = nn.Identity(),
+        state_pred_network: Module | None = None,
+        embed_past_action = False,
+        state_pred_loss_weight = 0.05,
         reward_range: tuple[float, float] | None = None,
-        reward_shaping_fns:
+        reward_shaping_fns: (
+            MaybeOneRewardShaper |
+            list[MaybeOneRewardShaper] |
+            list[list[MaybeOneRewardShaper]]
+        ) = None,
         num_reward_bins = 32,
         hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
         calc_gae_kwargs: dict = dict(),
-
-
+        parse_env_reset_out: Callable | None = None,
+        parse_env_step_out: Callable | None = None,
+        recurrent_cache = True,
+        use_spo = False, # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
+        asymmetric_spo = False, # https://openreview.net/pdf?id=BA6n0nmagi
+        max_mem_segments = 1
     ):
         super().__init__()

         if isinstance(transformer, dict):
-            transformer = TransformerXL(**transformer)
+            transformer = TransformerXL(max_mem_segments = max_mem_segments, **transformer)

         self.transformer = transformer

         # handle state embedder

-        if isinstance(embedder,
-            embedder =
+        if isinstance(embedder, dict):
+            embedder = StateEmbedder(**embedder)

         self.embedder = embedder

         # unembed state to actions or ssl predictions

+        action_embedder = None
+        if isinstance(unembedder, dict):
+            action_embedder, unembedder = EmbedAndReadout(
+                explicit_single_action_dim_given = True,
+                **unembedder,
+            )
+
         self.unembedder = unembedder

+        # embedding past actions
+
+        self.past_action_embedder = None
+        self.embed_past_action = embed_past_action
+
+        if embed_past_action and exists(action_embedder):
+            self.past_action_embedder = action_embedder
+
+        # attention window related
+
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
+        self.max_mem_segments = max_mem_segments
+
+        # policy network
+
+        self.policy_network = policy_network

         # determine value network, using HL Gauss Layer

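Note: `MemoryMLP.store` is the TTT/Titans-style update the comment references: per batch element, `torch.func` takes the gradient of a token-weighted MSE between the memory MLP's predictions on keys and the target values, and that gradient is subtracted from a forget-gated copy of the current fast weights. A toy reduction of the same mechanism to a single linear "memory" (names and shapes here are illustrative):

    import torch
    from torch.func import functional_call, grad, vmap

    memory = torch.nn.Linear(8, 8, bias = False)
    init_params = dict(memory.named_parameters())

    def loss_fn(params, inputs):
        keys, values, lr = inputs
        pred = functional_call(memory, params, keys)
        return (torch.nn.functional.mse_loss(pred, values, reduction = 'none') * lr).mean()

    grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0, 0)))

    b = 2
    params = {k: v.expand(b, *v.shape) for k, v in init_params.items()}  # per-sample fast weights
    keys, values = torch.randn(b, 16, 8), torch.randn(b, 16, 8)
    lr, forget = torch.rand(b, 16, 1), torch.rand(b)

    grads = grad_fn(params, (keys, values, lr))

    # the store rule: W_next = forget * W - dLoss/dW, per batch element
    next_params = {k: forget.view(b, 1, 1) * w - grads[k] for k, w in params.items()}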
@@ -953,6 +1194,22 @@ class Locoformer(Module):
             **hl_gauss_loss_kwargs
         )

+        # state prediction related
+
+        self.can_pred_state = exists(state_pred_network)
+        self.state_pred_network = state_pred_network
+
+        if exists(state_pred_network):
+            dim_states = self.embedder.dim_states
+            total_dim_states = sum(dim_states)
+
+            selectors = [t.tolist() for t in arange(total_dim_states).split(dim_states)]
+
+            self.state_pred_head = Readout(transformer.dim, num_continuous = total_dim_states, selectors = selectors)
+
+        self.has_state_pred_loss = state_pred_loss_weight > 0.
+        self.state_pred_loss_weight = state_pred_loss_weight
+
         # ppo related

         self.discount_factor = discount_factor
@@ -960,6 +1217,9 @@ class Locoformer(Module):
         self.ppo_eps_clip = ppo_eps_clip
         self.ppo_entropy_weight = ppo_entropy_weight
         self.ppo_value_clip = ppo_value_clip
+        self.ppo_soft_constrain_action_max = ppo_soft_constrain_action_max
+        self.ppo_soft_constrain_action_loss_weight = ppo_soft_constrain_action_loss_weight
+
         self.value_loss_weight = value_loss_weight

         self.calc_gae_kwargs = calc_gae_kwargs
@@ -968,14 +1228,26 @@ class Locoformer(Module):

         self.use_spo = use_spo

+        self.asymmetric_spo = asymmetric_spo
+
         # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688

-        self.
+        self.recurrent_cache = recurrent_cache
+
+        # environment returns to dictionary
+
+        self.parse_env_reset_out = default(parse_env_reset_out, default_parse_env_reset_out)
+        self.parse_env_step_out = default(parse_env_step_out, default_parse_env_step_out)

         # reward shaping function

         self.has_reward_shaping = exists(reward_shaping_fns)
+
+        if is_bearable(reward_shaping_fns, OneRewardShaper):
+            reward_shaping_fns = [reward_shaping_fns]
+
         self.reward_shaping_fns = reward_shaping_fns
+        self.reward_shaping_fns_multiple_envs = is_bearable(reward_shaping_fns, list[list[OneRewardShaper]])

         # loss related

@@ -986,7 +1258,10 @@ class Locoformer(Module):
         return next(self.parameters()).device

     def actor_parameters(self):
-        return
+        return [
+            *self.policy_network.parameters(),
+            *self.unembedder.parameters()
+        ]

     def critic_parameters(self):
         if not exists(self.to_value_pred):
@@ -994,6 +1269,69 @@ class Locoformer(Module):

         return self.to_value_pred.parameters()

+    @beartype
+    def learn(
+        self,
+        optims,
+        accelerator,
+        replay,
+        state_embed_kwargs: dict,
+        action_select_kwargs: dict,
+        state_id_kwarg: dict = dict(),
+        batch_size = 16,
+        epochs = 2,
+        use_vision = False,
+        compute_state_pred_loss = False,
+        state_pred_loss_weight = None,
+        maybe_construct_trial_from_buffer: Callable[[ReplayBuffer], Tensor] | None = None
+    ):
+        state_field = 'state_image' if use_vision else 'state'
+
+        episode_mapping = None
+
+        if exists(maybe_construct_trial_from_buffer):
+            episode_mapping = maybe_construct_trial_from_buffer(replay)
+
+        dataset = replay.dataset()
+
+        if exists(episode_mapping):
+            dataset = RemappedReplayDataset(dataset, episode_mapping)
+
+        dl = replay.dataloader(
+            batch_size = batch_size,
+            dataset = dataset,
+            shuffle = True
+        )
+
+        self, dl, *optims = accelerator.prepare(self, dl, *optims)
+
+        for _ in range(epochs):
+            for data in dl:
+
+                data = SimpleNamespace(**data)
+
+                actor_loss, critic_loss = self.ppo(
+                    state = getattr(data, state_field),
+                    internal_state = getattr(data, 'internal_state', None),
+                    action = data.action,
+                    action_log_prob = data.action_log_prob,
+                    reward = data.reward,
+                    value = data.value,
+                    done = data.done,
+                    condition = getattr(data, 'condition', None),
+                    cond_mask = getattr(data, 'cond_mask', None),
+                    episode_lens = data._lens,
+                    optims = optims,
+                    state_embed_kwargs = state_embed_kwargs,
+                    action_select_kwargs = action_select_kwargs,
+                    state_id_kwarg = state_id_kwarg,
+                    compute_state_pred_loss = compute_state_pred_loss,
+                    state_pred_loss_weight = state_pred_loss_weight,
+                    accelerator = accelerator
+                )
+
+                accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
     def evolve(
         self,
         environment,
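Note: the new `learn` method is a thin training driver: it optionally remaps episodes into multi-trial sequences, builds a padded dataloader from the replay buffer, runs everything through Hugging Face Accelerate, and calls `self.ppo` per batch. A hypothetical invocation - `locoformer` and `replay` are assumed to have been constructed elsewhere, and the kwargs are illustrative:

    from torch.optim import Adam
    from accelerate import Accelerator

    accelerator = Accelerator()

    actor_optim = Adam(locoformer.actor_parameters(), lr = 3e-4)   # assumes a built Locoformer instance
    critic_optim = Adam(locoformer.critic_parameters(), lr = 1e-3)

    locoformer.learn(
        optims = [actor_optim, critic_optim],
        accelerator = accelerator,
        replay = replay,                                # assumes a filled ReplayBuffer
        state_embed_kwargs = dict(state_type = 'raw'),
        action_select_kwargs = dict(),
        batch_size = 16,
        epochs = 2
    )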
@@ -1005,47 +1343,66 @@ class Locoformer(Module):
     def ppo(
         self,
         state,
+        internal_state,
         action,
-
+        action_log_prob,
         reward,
-
-
+        value,
+        done,
         episode_lens,
         condition: Tensor | None = None,
-
-
-
+        cond_mask: Tensor | None = None,
+        optims: list[Optimizer] | None = None,
+        state_embed_kwargs: dict = dict(),
+        action_select_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
+        compute_state_pred_loss = True,
+        state_pred_loss_weight = None,
+        accelerator = None,
+        max_grad_norm = 0.5
     ):
-
-        total_learnable_tokens = mask.sum().item()
+        state_pred_loss_weight = default(state_pred_loss_weight, self.state_pred_loss_weight)

+        window_size = self.window_size
+        mask = ~done
         seq_len = state.shape[1]
-
+        padding_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+        gae_mask = padding_mask & mask
+
+        total_learnable_tokens = gae_mask.sum().item()
+
+        advantage, returns = calc_gae(reward, value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)

-        advantage
+        advantage = normalize(advantage, mask = gae_mask)

-        advantage =
+        advantage = rearrange(advantage, '... -> ... 1')

-
-
-
-
-
-
-
-
-
+        past_action = pad_at_dim(action, (1, -1), dim = -2)
+
+        data_dict = dict(
+            state = state,
+            internal_state = internal_state,
+            action = action,
+            past_action = past_action,
+            old_action_log_prob = action_log_prob,
+            reward = reward,
+            mask = mask,
+            advantage = advantage,
+            returns = returns,
+            windowed_gae_mask = gae_mask,
+            condition = condition,
+            cond_mask = cond_mask
         )

-
+        num_windows = math.ceil(seq_len / window_size)

-
-        data_tensors = (*data_tensors, condition)
+        windowed_data = dict()

-
-
-
-
+        for name, tensor in data_dict.items():
+            if exists(tensor):
+                windowed_data[name] = tensor.split(window_size, dim = 1)
+            else:
+                windowed_data[name] = (None,) * num_windows

         mean_actor_loss = self.zero.clone()
         mean_critic_loss = self.zero.clone()
@@ -1054,56 +1411,90 @@ class Locoformer(Module):
|
|
|
1054
1411
|
|
|
1055
1412
|
cache = None
|
|
1056
1413
|
|
|
1057
|
-
for (
|
|
1058
|
-
state,
|
|
1059
|
-
action,
|
|
1060
|
-
old_action_log_prob,
|
|
1061
|
-
reward,
|
|
1062
|
-
old_value,
|
|
1063
|
-
mask,
|
|
1064
|
-
advantage,
|
|
1065
|
-
returns,
|
|
1066
|
-
*rest
|
|
1067
|
-
) in zip(*windowed_tensors):
|
|
1414
|
+
for window_tensors in zip(*windowed_data.values()):
|
|
1068
1415
|
|
|
1069
|
-
|
|
1070
|
-
condition, = rest
|
|
1416
|
+
data = SimpleNamespace(**dict(zip(windowed_data.keys(), window_tensors)))
|
|
1071
1417
|
|
|
1072
|
-
(action_logits, value_logits), cache = self.forward(state, condition = condition,
|
|
1073
|
-
entropy = calc_entropy(action_logits)
|
|
1418
|
+
((action_logits, maybe_state_pred), value_logits), cache = self.forward(data.state, past_action = data.past_action if self.embed_past_action else None, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': data.internal_state}, action_select_kwargs = action_select_kwargs, state_id_kwarg = state_id_kwarg, condition = data.condition, cond_mask = data.cond_mask, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True, return_state_pred = True)
|
|
1074
1419
|
|
|
1075
|
-
|
|
1076
|
-
log_prob = action_logits.gather(-1, action)
|
|
1077
|
-
log_prob = rearrange(log_prob, 'b t 1 -> b t')
|
|
1420
|
+
log_prob = self.unembedder.log_prob(action_logits, data.action, **action_select_kwargs)
|
|
1078
1421
|
|
|
1079
1422
|
# update actor, classic clipped surrogate loss
|
|
1080
1423
|
|
|
1081
         eps_clip = self.ppo_eps_clip
-        ratio = (log_prob - old_action_log_prob).exp()
+        ratio = (log_prob - data.old_action_log_prob).exp()
+
+        calc_spo = lambda: -(ratio * data.advantage - (data.advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
 
-
-
+        calc_ppo = lambda: -torch.min(ratio * data.advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * data.advantage)
+
+        if self.asymmetric_spo:
+            actor_loss = torch.where(data.advantage >= 0, calc_ppo(), calc_spo())
+        elif self.use_spo:
+            actor_loss = calc_spo()
         else:
-            actor_loss =
+            actor_loss = calc_ppo()
 
-
+        # maybe entropy
 
-
-
+        if self.ppo_entropy_weight > 0.:
+            entropy = self.unembedder.entropy(action_logits, **action_select_kwargs)
 
-
+            if exists(entropy):
+                actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+        windowed_actor_loss = actor_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
+
+        # maybe add state prediction
 
-
+        if (
+            exists(maybe_state_pred) and
+            self.has_state_pred_loss and
+            compute_state_pred_loss and
+            data.windowed_gae_mask[:, :-1].any()
+        ):
+            state_pred = maybe_state_pred[:, :-1]
+            state_labels = data.state[:, 1:]
+            loss_mask = data.windowed_gae_mask[:, :-1]
+
+            state_id = state_id_kwarg.get('state_id', 0)
+
+            state_pred_loss = self.state_pred_head.calculate_loss(state_pred, state_labels, selector_index = state_id, return_unreduced_loss = True)
+
+            state_pred_loss = state_pred_loss.mean(dim = -1) # average over state features
+
+            windowed_state_pred_loss = state_pred_loss[loss_mask].sum() / total_learnable_tokens
+
+            windowed_actor_loss = (
+                windowed_actor_loss +
+                windowed_state_pred_loss * state_pred_loss_weight
+            )
+
+        # maybe soft constrain continuous actions
+
+        if (
+            self.ppo_soft_constrain_action_max and
+            self.unembedder.has_continuous
+        ):
+            loss_mask = data.windowed_gae_mask
+
+            soft_constrain_loss = (action_logits[..., 0].abs() - self.ppo_soft_constrain_action_max).relu().pow(2)
+            windowed_soft_constrain_loss = soft_constrain_loss[loss_mask].sum() / total_learnable_tokens
 
-
-
+            windowed_actor_loss = (
+                windowed_actor_loss +
+                windowed_soft_constrain_loss * self.ppo_soft_constrain_action_loss_weight
+            )
 
-
-        clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+        # windowed loss
 
-
+        windowed_actor_loss.backward(retain_graph = True)
+
+        # update critic
 
-
+        value_loss = self.hl_gauss_loss(value_logits, data.returns, reduction = 'none') * self.value_loss_weight
+
+        windowed_critic_loss = value_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
         windowed_critic_loss.backward(retain_graph = True)
 
         # accumulate
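The hunk above replaces the bare PPO clipped objective with a choice between PPO and SPO (a quadratic-penalty surrogate), optionally mixed by the sign of the advantage. Below is a minimal, self-contained sketch of the two surrogates as they appear in the new code; the function names and the inline test values are illustrative, not part of the package's API:

import torch
from torch import Tensor

def ppo_surrogate(ratio: Tensor, advantage: Tensor, eps_clip = 0.2) -> Tensor:
    # classic clipped surrogate: take the worse of unclipped vs clipped
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
    return -torch.min(ratio * advantage, clipped_ratio * advantage)

def spo_surrogate(ratio: Tensor, advantage: Tensor, eps_clip = 0.2) -> Tensor:
    # quadratic penalty on (ratio - 1), scaled by |advantage| / (2 * eps_clip),
    # in place of hard clipping - the objective stays smooth everywhere
    penalty = advantage.abs() * (ratio - 1.).square() / (2 * eps_clip)
    return -(ratio * advantage - penalty)

ratio = torch.rand(8) * 0.4 + 0.8   # stand-in for exp(log_prob - old_log_prob)
advantage = torch.randn(8)

# asymmetric variant: clipped PPO on positive advantages, SPO penalty on negative
loss = torch.where(advantage >= 0, ppo_surrogate(ratio, advantage), spo_surrogate(ratio, advantage)).mean()

Hard clipping zeroes the gradient once the ratio leaves the trust region, whereas the SPO penalty keeps a restoring gradient toward ratio = 1; the asymmetric mode applies that restoring force only where advantages are negative.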
@@ -1113,13 +1504,16 @@ class Locoformer(Module):
 
         # optimizer update
 
-        if exists(
-
-
+        if exists(optims):
+
+            if exists(accelerator):
+                accelerator.clip_grad_norm_(self.parameters(), max_grad_norm)
+            else:
+                nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
 
-
-
-
+            for optim in optims:
+                optim.step()
+                optim.zero_grad()
 
         # return losses for logging
 
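The windowed actor and critic losses are backpropagated separately (with retain_graph so shared transformer activations survive both calls), then gradients are clipped once, through an Accelerate accelerator when present, and every optimizer is stepped. A minimal sketch of that accumulate-then-step pattern, assuming hypothetical model, losses, and optimizer handles outside the package:

import torch
from torch import nn

def apply_updates(
    model: nn.Module,
    losses: list[torch.Tensor],
    optims: list[torch.optim.Optimizer],
    max_grad_norm = 0.5
):
    # accumulate gradients loss by loss, keeping the graph alive for all but the last
    for i, loss in enumerate(losses):
        loss.backward(retain_graph = i < (len(losses) - 1))

    # one clip over all parameters, then step and zero every optimizer
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    for optim in optims:
        optim.step()
        optim.zero_grad()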
@@ -1128,14 +1522,18 @@ class Locoformer(Module):
     def state_and_command_to_rewards(
         self,
         state,
-        commands = None
+        commands = None,
+        env_index: int | None = None
     ) -> Tensor:
 
         assert self.has_reward_shaping
+        assert xnor(exists(env_index), self.reward_shaping_fns_multiple_envs), f'`env_index` must be passed in if multiple reward shaping functions are defined, and vice versa (not passed in if only single list of reward shaping functions)'
 
         rewards = []
 
-
+        reward_shaping_fns = self.reward_shaping_fns[env_index] if exists(env_index) else self.reward_shaping_fns
+
+        for fn in reward_shaping_fns:
             param_names = get_param_names(fn)
             param_names = set(param_names) & {'state', 'command'}
 
@@ -1152,9 +1550,23 @@ class Locoformer(Module):
 
         rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
 
-
+        assert all([r.numel() == 1 for r in rewards])
+
+        if len(rewards) == 0:
+            return None
 
-
+        packed_rewards, _ = pack(rewards, '*')
+        return packed_rewards
+
+    @beartype
+    def wrap_env_functions(
+        self,
+        env,
+        env_output_transforms: dict[str, Callable] = dict(),
+        state_transform: Callable = identity,
+        reward_norm = 1.,
+        command_generator: Callable = always(None)
+    ):
 
         def transform_output(el):
             if isinstance(el, ndarray):
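state_and_command_to_rewards calls each shaping function with only the keyword arguments its signature declares (among state and command), and with this release can index a per-environment list of shaping functions via env_index. A rough sketch of the signature-based dispatch, with get_param_names approximated via inspect and two made-up shaping functions:

from inspect import signature

import torch
from torch import tensor, is_tensor
from einops import pack

def get_param_names(fn):
    return list(signature(fn).parameters.keys())

# hypothetical shaping functions - one reads only the state, one also reads the command
def upright_bonus(state):
    return 1.0 - state[..., 0].abs().mean().item()

def velocity_tracking(state, command):
    return -float((state[..., 1] - command[..., 0]).abs().mean())

def shaped_rewards(fns, state, command = None):
    rewards = []
    for fn in fns:
        names = set(get_param_names(fn)) & {'state', 'command'}
        available = dict(state = state, command = command)
        rewards.append(fn(**{name: available[name] for name in names}))

    rewards = [tensor(r) if not is_tensor(r) else r for r in rewards]
    packed, _ = pack(rewards, '*')  # one scalar reward per shaping function
    return packed

state = torch.randn(4)
command = torch.randn(2)
print(shaped_rewards([upright_bonus, velocity_tracking], state, command))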
@@ -1167,23 +1579,57 @@ class Locoformer(Module):
         def wrapped_reset(*args, **kwargs):
             env_reset_out = env.reset(*args, **kwargs)
 
-
+            env_reset_out_torch = tree_map(transform_output, env_reset_out)
+
+            env_reset_out_dict = self.parse_env_reset_out(env_reset_out_torch)
+
+            env_reset_out_dict['state'] = state_transform(env_reset_out_dict['state'])
+
+            derived_states = dict()
 
-
+            for derived_name, transform in env_output_transforms.items():
+                derived_states[derived_name] = transform(env_reset_out_dict, env)
+
+            env_reset_out_dict['derived_state'] = derived_states
+
+            return env_reset_out_dict
+
+        def wrapped_step(action, *args, command = None, env_index = None, **kwargs):
 
             if is_tensor(action):
-
+                if action.numel() == 1:
+                    action = action.item()
+                else:
+                    action = action.tolist()
 
             env_step_out = env.step(action, *args, **kwargs)
 
             env_step_out_torch = tree_map(transform_output, env_step_out)
 
-
-
+            env_step_out_dict = self.parse_env_step_out(env_step_out_torch)
+
+            env_step_out_dict['state'] = state_transform(env_step_out_dict['state'])
+
+            env_step_out_dict['reward'] = env_step_out_dict['reward'] / reward_norm
 
-
+            if self.has_reward_shaping:
+                shaped_rewards = self.state_and_command_to_rewards(env_step_out_dict['state'], command, env_index = env_index)
 
-
+                if exists(shaped_rewards):
+                    env_step_out_dict['shaped_rewards'] = shaped_rewards
+
+                    # add shaped rewards to main reward
+
+                    env_step_out_dict['reward'] = env_step_out_dict['reward'] + shaped_rewards.sum()
+
+            derived_states = dict()
+
+            for derived_name, transform in env_output_transforms.items():
+                derived_states[derived_name] = transform(env_step_out_dict, env)
+
+            env_step_out_dict['derived_state'] = derived_states
+
+            return env_step_out_dict
 
         return wrapped_reset, wrapped_step
 
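wrap_env_functions now parses reset/step outputs into dicts, applies state_transform, normalizes rewards, folds shaped rewards into the main reward, and recomputes the configured derived states on every call. A usage sketch, under the assumption that model is an already-configured Locoformer instance (the derived-state transform shown is made up for illustration):

import gymnasium as gym
import torch

env = gym.make('CartPole-v1')

# `model` is assumed to be an already-configured Locoformer instance
wrapped_reset, wrapped_step = model.wrap_env_functions(
    env,
    reward_norm = 1.,
    env_output_transforms = dict(
        # derived states are recomputed on every reset/step from the parsed output dict and the raw env
        internal_state = lambda out, env: out['state'][..., :2]
    )
)

reset_out = wrapped_reset(seed = 42)
state = reset_out['state']            # torch tensor, already passed through state_transform
derived = reset_out['derived_state']  # dict of derived tensors

step_out = wrapped_step(torch.tensor(0))  # numel-1 tensors become python scalars for env.step
# step_out carries 'state', 'reward' (normalized, plus shaped rewards when configured),
# 'terminated', 'truncated', and 'derived_state'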
@@ -1196,14 +1642,13 @@ class Locoformer(Module):
         state_time_dim = 1,
         **kwargs
     ):
-        window_size = self.window_size
 
         cache = None
 
         def stateful_forward(
             state: Tensor,
             condition: Tensor | None = None,
-
+            cond_mask: Tensor | None = None,
             **override_kwargs
         ):
             nonlocal cache
@@ -1213,6 +1658,9 @@ class Locoformer(Module):
             if exists(condition):
                 condition = condition.to(self.device)
 
+            if exists(cond_mask):
+                cond_mask = cond_mask.to(self.device)
+
             # handle no batch or time, for easier time rolling out against envs
 
             if not has_batch_dim:
@@ -1221,15 +1669,27 @@ class Locoformer(Module):
                 if exists(condition):
                     condition = rearrange(condition, '... -> 1 ...')
 
+                if exists(cond_mask):
+                    cond_mask = rearrange(cond_mask, '... -> 1 ...')
+
             if not has_time_dim:
                 state = state.unsqueeze(state_time_dim)
 
                 if exists(condition):
                     condition = rearrange(condition, '... d -> ... 1 d')
 
+                if exists(cond_mask):
+                    cond_mask = cond_mask.unsqueeze(state_time_dim)
+
             # forwards
 
-            out, cache = self.forward(
+            out, cache = self.forward(
+                state,
+                condition = condition,
+                cond_mask = cond_mask,
+                cache = cache,
+                **{**kwargs, **override_kwargs}
+            )
 
             # maybe remove batch or time
 
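The stateful_forward closure threads the transformer cache across calls and now routes cond_mask through the same batch/time dimension handling as condition. A small rollout-style usage sketch, again assuming a configured model; dim_state and dim_command are hypothetical attributes standing in for the real state and command sizes:

import torch

# rollout-style usage: single unbatched, untimed state per call, cache handled internally
stateful_forward = model.get_stateful_forward(
    has_batch_dim = False,
    has_time_dim = False,
    inference_mode = True
)

state = torch.randn(model.dim_state)      # hypothetical state dimension attribute
command = torch.randn(model.dim_command)  # hypothetical command dimension attribute

logits = stateful_forward(
    state,
    condition = command,
    cond_mask = torch.tensor(True)  # flags whether the condition is actually present
)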
@@ -1260,60 +1720,275 @@ class Locoformer(Module):
 
         return stateful_forward, initial_logits
 
+    @beartype
+    def gather_experience_from_env_(
+        self,
+        wrapped_env_functions: tuple[Callable, Callable],
+        replay: ReplayBuffer,
+        embed_past_action = False,
+        max_timesteps = None,
+        use_vision = False,
+        action_select_kwargs: dict = dict(),
+        state_embed_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
+        env_index: int | None = None,
+        state_entropy_bonus_weight = 0.,
+        action_rescale_range: tuple[float, float] | None = None,
+        command_fn: Callable = always(None)
+    ):
+
+        env_reset, env_step = wrapped_env_functions
+
+        reset_out_dict = env_reset()
+        derived, state = pick(reset_out_dict, ('derived_state', 'state'))
+
+        state_image = derived.get('state_image', None)
+        internal_state = derived.get('internal_state', None)
+
+        timestep = 0
+
+        max_timesteps = default(max_timesteps, replay.max_timesteps)
+
+        stateful_forward = self.get_stateful_forward(
+            has_batch_dim = False,
+            has_time_dim = False,
+            inference_mode = True
+        )
+
+        cum_rewards = 0.
+
+        with replay.one_episode() as final_meta_data_store_dict:
+
+            past_action = None
+
+            while True:
+                state_for_model = state_image if use_vision else state
+
+                maybe_command = command_fn(state_for_model)
+
+                # predict next action
+
+                (action_logits, state_pred), value = stateful_forward(
+                    state_for_model,
+                    condition = maybe_command,
+                    cond_mask = tensor(exists(maybe_command)),
+                    state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state},
+                    action_select_kwargs = action_select_kwargs,
+                    state_id_kwarg = state_id_kwarg,
+                    past_action = past_action if embed_past_action else None,
+                    return_values = True,
+                    return_state_pred = True
+                )
+
+                action = self.unembedder.sample(action_logits, **action_select_kwargs)
+
+                # maybe clip
+
+                if exists(action_rescale_range):
+                    min_val, max_val = action_rescale_range
+                    action = (action + 1.) * 0.5 * (max_val - min_val) + min_val
+
+                # pass to environment
+
+                step_dict = env_step(action, command = maybe_command, env_index = env_index)
+
+                derived, next_state, reward, terminated, truncated = pick(step_dict, ('derived_state', 'state', 'reward', 'terminated', 'truncated'))
+
+                next_state_image = derived.get('state_image', None)
+                next_internal_state = derived.get('internal_state', None)
+
+                # maybe state entropy bonus
+
+                if state_entropy_bonus_weight > 0. and exists(state_pred):
+                    state_id = state_id_kwarg.get('state_id', 0)
+                    entropy = self.state_pred_head.entropy(state_pred, selector_index = state_id)
+
+                    state_entropy_bonus = (entropy * state_entropy_bonus_weight).sum()
+
+                    reward = reward + state_entropy_bonus.item() # the entropy is directly related to log variance
+
+                # cum rewards
+
+                cum_rewards += reward
+
+                # increment counters
+                # we will store the step with done=False, as only the bootstrap/boundary node is done=True
+
+                exceeds_max_timesteps = max_timesteps >= 0 and timestep == (max_timesteps - 1)
+                should_stop = truncated or terminated or tensor(exceeds_max_timesteps)
+
+                # get log prob of action
+
+                action_log_prob = self.unembedder.log_prob(action_logits, action, **action_select_kwargs)
+
+                memory = replay.store(
+                    state = state,
+                    state_image = state_image,
+                    action = action,
+                    action_log_prob = action_log_prob,
+                    internal_state = internal_state,
+                    reward = reward,
+                    value = value,
+                    done = tensor(False),
+                    condition = maybe_command,
+                    cond_mask = tensor(exists(maybe_command))
+                )
+
+                timestep += 1
+
+                # break if done or exceed max timestep
+                if should_stop:
+
+                    # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                    # only if terminated signal not detected
+                    if not terminated:
+                        next_state_for_model = next_state_image if use_vision else next_state
+
+                        _, next_value = stateful_forward(next_state_for_model, condition = maybe_command, cond_mask = tensor(exists(maybe_command)), return_values = True, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state}, state_id_kwarg = state_id_kwarg, action_select_kwargs = action_select_kwargs)
+
+                        terminal_node = dict(
+                            state = next_state,
+                            state_image = next_state_image,
+                            internal_state = next_internal_state,
+                            value = next_value,
+                            reward = next_value,
+                            done = True,
+                            condition = maybe_command,
+                            cond_mask = exists(maybe_command)
+                        )
+
+                    else:
+                        # terminal node - store a step with 0 reward and value, and done=True, to stop GAE scan
+                        terminal_node = dict(
+                            state = next_state,
+                            state_image = next_state_image,
+                            internal_state = next_internal_state,
+                            value = torch.zeros_like(value),
+                            reward = torch.zeros_like(reward),
+                            done = True,
+                            condition = maybe_command,
+                            cond_mask = exists(maybe_command)
+                        )
+
+                    terminal_node = {key: value for key, value in terminal_node.items() if key in memory._fields}
+
+                    terminal_memory = memory._replace(**terminal_node)
+
+                    replay.store(**terminal_memory._asdict())
+
+                    # store the final cumulative reward into meta data
+
+                    final_meta_data_store_dict.update(cum_rewards = cum_rewards)
+
+                    break
+
+                state = next_state
+                state_image = next_state_image
+                internal_state = next_internal_state
+
+                past_action = action
+
+        return cum_rewards
+
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: TransformerMemory | None = None,
         condition: Tensor | None = None,
-
+        cond_mask: Tensor | None = None,
+        past_action: Tensor | None = None,
+        state_embed_kwargs: dict = dict(),
+        action_select_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
         detach_cache = False,
         return_values = False,
+        return_state_pred = False,
         return_raw_value_logits = False
     ):
 
         state = state.to(self.device)
 
+        # move condition
+
+        if exists(condition):
+            condition = condition.to(self.device)
+
         # determine which function to invoke for state to token for transformer
 
         state_to_token = self.embedder
 
-        if exists(state_type):
-            state_to_token = self.embedder[state_type]
-
         # embed
 
-        tokens = state_to_token(state)
+        tokens = state_to_token(state, **state_embed_kwargs, **state_id_kwarg)
 
-        #
+        # maybe add past action
 
-
+        # determine if first window and start of sequence
 
-
+        total_tokens = cache.total_tokens if exists(cache) else 0
 
-
-        timestep_start = 0
+        is_start_of_sequence = total_tokens == 0
 
-
-
+        # maybe add past action
+
+        if exists(past_action):
+            assert self.embed_past_action
+            past_action_embed = self.past_action_embedder(past_action, **action_select_kwargs)
+
+            if is_start_of_sequence:
+                past_action_embed = pad_at_dim(past_action_embed[..., 1:, :], (1, 0), dim = -2)
+
+            tokens = tokens + past_action_embed
+
+        # time
+
+        time = tokens.shape[-2]
 
         # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
 
-        assert ((
+        assert ((total_tokens % self.window_size) + time) <= self.window_size
 
         # attention
 
-
+        if not exists(cache):
+            memory_segments = deque(maxlen = self.max_mem_segments)
+        else:
+            total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
+
+        attn_cache = cache
+
+        embed, cache = self.transformer(
+            tokens,
+            condition = condition,
+            cond_mask = cond_mask,
+            cache = attn_cache,
+            return_kv_cache = True
+        )
 
         # unembed to actions - in language models this would be the next state
 
-
+        policy_embed = self.policy_network(embed)
+
+        action_logits = self.unembedder(policy_embed, **action_select_kwargs)
 
         out = action_logits
 
+        # maybe return state prediction
+
+        if return_state_pred:
+            state_pred = None
+
+            if self.can_pred_state:
+                state_id = state_id_kwarg.get('state_id', 0)
+                state_pred_embed = self.state_pred_network(embed)
+                state_pred = self.state_pred_head(state_pred_embed, selector_index = state_id)
+
+            out = (out, state_pred)
+
         # maybe detach cache
 
         if detach_cache:
-
+            cache = tree_map_tensor(cache, lambda t: t.detach())
 
         # handle returning of values
 
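When an episode stops, gather_experience_from_env_ appends one extra boundary node: on truncation or timeout it carries the critic's next_value (as both value and reward) so advantage estimation can bootstrap past the cutoff, while on true termination it carries zeros so no credit leaks beyond the end. For reference, a textbook single-episode GAE scan that consumes exactly such a trailing bootstrap value; this is the standard formulation, not necessarily the package's internal scan:

import torch
from torch import Tensor

def gae(rewards: Tensor, values: Tensor, gamma = 0.99, lam = 0.95) -> Tensor:
    # `values` has one more entry than `rewards`: the last entry is the bootstrap
    # value at the boundary node, and is zero on true termination
    advantages = torch.zeros_like(rewards)
    running = 0.

    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        running = delta + gamma * lam * running                 # discounted accumulation
        advantages[t] = running

    return advantages

rewards = torch.randn(16)
values = torch.randn(17)  # 16 steps + 1 boundary value
advantages = gae(rewards, values)
returns = advantages + values[:-1]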
@@ -1327,20 +2002,22 @@ class Locoformer(Module):
 
             out = (out, values)
 
-        #
+        # handle curtailing kv cache at the right intervals
 
-
+        total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
 
-
+        if self.fixed_window_size or divisible_by(total_tokens, self.window_size * 2):
+            kv_cache = kv_cache[..., -self.window_size:, :]
 
-
+        if self.recurrent_cache and divisible_by(total_tokens, self.window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
 
-
-
+            if exists(gru_cache):
+                gru_cache = torch.roll(gru_cache, shifts = -1, dims = 0)
 
-
+        if divisible_by(total_tokens, self.window_size):
+            memory_segments.append(kv_cache.detach())
 
-
-        kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+        cache = TransformerMemory(total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments)
 
-        return out,
+        return out, cache
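The closing cache logic keeps memory bounded: the kv cache is truncated back to one window every two windows of tokens (or rolled, for the recurrent variant), and a detached snapshot is pushed onto a deque capped at max_mem_segments. A toy illustration of just that truncation arithmetic, independent of the model itself; the shapes and hyperparameters here are made up:

from collections import deque
import torch

def divisible_by(num, den):
    return (num % den) == 0

window_size = 4
max_mem_segments = 3

kv_cache = torch.empty(2, 0, 8)  # toy (heads, seq, dim) cache, grown one token at a time
memory_segments = deque(maxlen = max_mem_segments)

for total_tokens in range(1, 17):
    new_kv = torch.randn(2, 1, 8)
    kv_cache = torch.cat((kv_cache, new_kv), dim = -2)

    # truncate back to one window every two windows worth of tokens
    if divisible_by(total_tokens, window_size * 2):
        kv_cache = kv_cache[..., -window_size:, :]

    # snapshot a detached segment at every window boundary
    if divisible_by(total_tokens, window_size):
        memory_segments.append(kv_cache.detach())

print(kv_cache.shape[-2], len(memory_segments))  # cache and segment count both stay bounded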