locoformer 0.0.29__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
- locoformer/locoformer.py +1356 -392
- {locoformer-0.0.29.dist-info → locoformer-0.1.1.dist-info}/METADATA +30 -2
- locoformer-0.1.1.dist-info/RECORD +6 -0
- {locoformer-0.0.29.dist-info → locoformer-0.1.1.dist-info}/WHEEL +1 -1
- locoformer-0.0.29.dist-info/RECORD +0 -6
- {locoformer-0.0.29.dist-info → locoformer-0.1.1.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,10 +1,16 @@
 from __future__ import annotations
+import math
 from typing import Callable
-from
+from types import SimpleNamespace
+from functools import partial, wraps

 from pathlib import Path
 from contextlib import contextmanager
-from collections import namedtuple
+from collections import namedtuple, deque
+
+from glom import glom
+
+from inspect import signature

 import numpy as np
 from numpy import ndarray
@@ -14,16 +20,17 @@ from beartype import beartype
 from beartype.door import is_bearable

 import torch
-from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
+from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy, nested
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils.data import Dataset, DataLoader
+from torch.distributions import Normal
 from torch.optim import Optimizer

 import einx
-from einops import rearrange, einsum
-from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, einsum, reduce, pack
+from einops.layers.torch import Rearrange, Reduce

 from rotary_embedding_torch import RotaryEmbedding

@@ -31,11 +38,28 @@ from hl_gauss_pytorch import HLGaussLoss

 from assoc_scan import AssocScan

+from x_mlps_pytorch import MLP
+
+from x_evolution import EvoStrategy
+
+from discrete_continuous_embed_readout import EmbedAndReadout, Embed, Readout
+
+from hyper_connections import mc_get_init_and_expand_reduce_stream_functions
+
+from memmap_replay_buffer import ReplayBuffer, ReplayDataset
+
 # constants

 LinearNoBias = partial(Linear, bias = False)

-
+TransformerMemory = namedtuple('TransformerMemory', (
+    'total_tokens',
+    'kv_cache',
+    'gru_cache',
+    'mem_mlp_cache',
+    'mem_mlp_hidden_states',
+    'memory_segments'
+))

 # helper functions

@@ -45,12 +69,55 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def always(val):
+    def inner(*args, **kwargs):
+        return val
+
+    return inner
+
+def identity(t, *args, **kwargs):
+    return t
+
+def pick(data, keys):
+    return tuple(data[k] for k in keys)
+
 def first(arr):
     return arr[0]

+def xnor(x, y):
+    return not (x ^ y)
+
 def divisible_by(num, den):
     return (num % den) == 0

+def get_param_names(fn):
+    parameters = signature(fn).parameters
+    return list(parameters.keys())
+
+def check_has_param_attr(
+    param_name,
+    param_attr,
+    default_value = None
+):
+    def decorator(fn):
+        sig = signature(fn)
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+
+            bound_args = sig.bind(*args, **kwargs).arguments
+
+            if not (
+                param_name in bound_args and
+                hasattr(bound_args[param_name], param_attr)
+            ):
+                return default_value
+
+            return fn(*args, **kwargs)
+
+        return inner
+    return decorator
+
 # tensor helpers

 def log(t, eps = 1e-20):
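
The `check_has_param_attr` decorator added here is what lets the reward-shaping functions later in this diff degrade gracefully: when the named argument lacks the required attribute, the wrapped function returns `default_value` (`None` by default) instead of raising. A standalone sketch of that behavior, restating the decorator in condensed form and using `SimpleNamespace` as a stand-in state:

# standalone sketch: how check_has_param_attr guards a reward function
from functools import wraps
from inspect import signature
from types import SimpleNamespace

import torch

def check_has_param_attr(param_name, param_attr, default_value = None):
    # same logic as the decorator in the diff above, condensed
    def decorator(fn):
        sig = signature(fn)

        @wraps(fn)
        def inner(*args, **kwargs):
            bound_args = sig.bind(*args, **kwargs).arguments

            if not (param_name in bound_args and hasattr(bound_args[param_name], param_attr)):
                return default_value

            return fn(*args, **kwargs)

        return inner
    return decorator

@check_has_param_attr('state', 'v_z')
def reward_base_linear_velocity_penalty(state):
    return -state.v_z.norm(dim = -1).pow(2)

print(reward_base_linear_velocity_penalty(SimpleNamespace(v_z = torch.tensor([0.1]))))  # tensor(-0.0100)
print(reward_base_linear_velocity_penalty(SimpleNamespace()))  # None - attribute missing, reward skipped
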
@@ -62,6 +129,11 @@ def is_empty(t):
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

+def lens_to_mask(lens, max_len):
+    device = lens.device
+    seq = arange(max_len, device = device)
+    return einx.less('j, i -> i j', seq, lens)
+
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -75,12 +147,167 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

-def
-
+def safe_cat(t, next_t, dim = -1):
+    if not exists(t):
+        return next_t
+
+    return cat((t, next_t), dim = dim)
+
+def normalize(t, mask = None, eps = 1e-5):
+    if exists(mask):
+        assert mask.any()
+
+    t_for_stats = t[mask] if exists(mask) else t
+    var, mean = torch.var_mean(t_for_stats)
+
+    return (t - mean) / var.sqrt().clamp_min(eps)
+
+def tensor_to_dict(
+    t: Tensor,
+    config: tuple[tuple[str, int] | str],
+    dim = -1,
+    return_dottable = True
+):
+    config = tuple((c, 1) if isinstance(c, str) else c for c in config)
+
+    names, sizes = zip(*config)
+    assert sum(sizes) == t.shape[dim]
+
+    t = t.split(sizes, dim = dim)
+    tensor_dict = dict(zip(names, t))
+
+    if not return_dottable:
+        return tensor_dict
+
+    return SimpleNamespace(**tensor_dict)
+
+# dataset related
+
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False,
+        num_trials_select = None
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+        assert not (exists(num_trials_select) and num_trials_select <= 0)
+        self.sub_select_trials = exists(num_trials_select)
+        self.num_trials_select = num_trials_select
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        # shuffle the episode indices if either shuffle episodes is turned on, or `num_trial_select` passed in (for sub selecting episodes from a set)
+
+        if (
+            episode_indices.numel() > 1 and
+            (self.shuffle_episodes or self.sub_select_trials)
+        ):
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        # crop out the episodes
+
+        if self.sub_select_trials:
+            episode_indices = episode_indices[:self.num_trials_select]
+
+        # now select out the episode data and merge along time
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
+# reward functions - A.2
+
+@check_has_param_attr('state', 'v_xy')
+@check_has_param_attr('command', 'v_xy')
+def reward_linear_velocity_command_tracking(
+    state,
+    command,
+    s1 = 1.
+):
+    error = (state.v_xy - command.v_xy).norm(dim = -1).pow(2)
+    return torch.exp(-error / s1)
+
+@check_has_param_attr('state', 'w_z')
+@check_has_param_attr('command', 'w_z')
+def reward_angular_velocity_command_tracking(
+    state,
+    command,
+    s2 = 1.
+):
+    error = (state.w_z - command.w_z).norm(dim = -1).pow(2)
+    return torch.exp(-error / s2)
+
+@check_has_param_attr('state', 'v_z')
+def reward_base_linear_velocity_penalty(
+    state
+):
+    return -state.v_z.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'w_xy')
+def reward_base_angular_velocity_penalty(
+    state
+):
+    return -state.w_xy.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'x_z')
+def reward_base_height_penalty(
+    state,
+    x_z_nominal = 0.27
+):
+    return -(state.x_z - x_z_nominal).norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'joint_q')
+def reward_joint_acceleration_penalty(
+    state
+):
+    return -state.joint_q.norm(dim = -1).pow(2)
+
+@check_has_param_attr('state', 'tau')
+def reward_torque_penalty(
+    state
+):
+    return -state.tau.norm(dim = -1).pow(2)

-def
-
-
+def reward_alive(
+    state
+):
+    return 1.

 # generalized advantage estimate

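
The two command-tracking rewards above follow `r = exp(-||err||^2 / s)`: a perfect track gives 1 and the reward decays smoothly toward 0 as the tracking error grows, while the remaining terms are plain negative squared-norm penalties. A quick worked check of the tracking shape, with illustrative values:

# worked check of the linear-velocity tracking reward: exp(-||err||^2 / s1)
import torch

v_xy_state   = torch.tensor([[0.5, 0.0]])   # actual base velocity, (batch, 2)
v_xy_command = torch.tensor([[1.0, 0.0]])   # commanded velocity

error = (v_xy_state - v_xy_command).norm(dim = -1).pow(2)  # ||err||^2 = 0.25
reward = torch.exp(-error / 1.)                            # exp(-0.25)
print(reward)  # tensor([0.7788])
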
@@ -185,283 +412,87 @@ def create_sliding_mask(
     create_kwargs = dict(device = device) if exists(device) else dict()
     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

-#
-
-def collate_var_time(data):
-
-    datum = first(data)
-    keys = datum.keys()
-
-    all_tensors = zip(*[datum.values() for datum in data])
-
-    collated_values = []
-
-    for key, tensors in zip(keys, all_tensors):
-
-        # the episode lens have zero dimension - think of a cleaner way to handle this later
-
-        if key != '_lens':
-
-            times = [t.shape[0] for t in tensors]
-            max_time = max(times)
-            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
-
-        collated_values.append(stack(tensors))
-
-    return dict(zip(keys, collated_values))
-
-class ReplayDataset(Dataset):
-    def __init__(
-        self,
-        folder: str | Path,
-        fields: tuple[str, ...] | None = None
-    ):
-        if isinstance(folder, str):
-            folder = Path(folder)
-
-        episode_lens = folder / 'episode_lens.npy'
-        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
-
-        # get indices of non-zero lengthed episodes
-
-        nonzero_episodes = self.episode_lens > 0
-        self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
-
-        # get all data files
-
-        filepaths = [*folder.glob('*.data.npy')]
-        assert len(filepaths) > 0
-
-        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
-
-        fieldnames_from_files = set(fieldname_to_filepath.keys())
-
-        fields = default(fields, fieldnames_from_files)
-
-        self.memmaps = dict()
-
-        for field in fields:
-            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
-
-            path = fieldname_to_filepath[field]
-
-            self.memmaps[field] = open_memmap(str(path), mode = 'r')
+# normalization + conditioning (needed for the commands to the robot)

-
-        return len(self.indices)
-
-    def __getitem__(self, idx):
-        episode_index = self.indices[idx]
-
-        episode_len = self.episode_lens[episode_index]
-
-        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
-
-        data['_lens'] = tensor(episode_len)
-
-        return data
-
-class RemappedReplayDataset(Dataset):
+class MaybeAdaRMSNormWrapper(Module):
     def __init__(
         self,
-
-
-
+        fn: Module,
+        dim,
+        dim_cond = None
     ):
-
-
-
-        if is_tensor(episode_mapping):
-            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
-            episode_mapping = episode_mapping.tolist()
-
-        self.episode_mapping = episode_mapping
-        self.shuffle_episodes = shuffle_episodes
-
-    def __len__(self):
-        return len(self.episode_mapping)
-
-    def __getitem__(self, idx):
-
-        episode_indices = self.episode_mapping[idx]
-
-        episode_indices = tensor(episode_indices)
-        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
-
-        assert not is_empty(episode_indices)
-
-        if self.shuffle_episodes and episode_indices.numel() > 1:
-            num_episodes = len(episode_indices)
-            episode_indices = episode_indices[torch.randperm(num_episodes)]
-
-        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
-
-        episode_lens = stack([data.pop('_lens') for data in episode_data])
-
-        keys = first(episode_data).keys()
-
-        values = [list(data.values()) for data in episode_data]
-
-        values = [cat(field_values) for field_values in zip(*values)] # concat across time
-
-        multi_episode_data = dict(zip(keys, values))
+        super().__init__()
+        condition = exists(dim_cond)

-
+        self.fn = fn
+        self.norm = nn.RMSNorm(dim, elementwise_affine = not condition)

-
+        self.accept_condition = condition

-
+        if condition:
+            self.to_gamma = LinearNoBias(dim_cond, dim)
+            self.to_ada_norm_zero = nn.Linear(dim_cond, dim)

-
+            nn.init.zeros_(self.to_gamma.weight)
+            nn.init.zeros_(self.to_ada_norm_zero.weight)
+            nn.init.constant_(self.to_ada_norm_zero.bias, -5.)

-
-    def __init__(
+    def forward(
         self,
-
-
-
-
-
-        str | tuple[str, int | tuple[int, ...]]
-        ]
+        x,
+        *args,
+        cond = None,
+        cond_mask = None,
+        **kwargs
     ):

-
-
-        if not isinstance(folder, Path):
-            folder = Path(folder)
-        folder.mkdir(exist_ok = True)
-
-        self.folder = folder
-        assert folder.is_dir()
-
-        # keeping track of episode length
-
-        episode_lens = folder / 'episode_lens.npy'
-
-        self.episode_index = 0
-        self.timestep_index = 0
-
-        self.max_episodes = max_episodes
-        self.max_timesteps= max_timesteps
+        need_cond = self.accept_condition
+        has_input_cond = need_cond and exists(cond)

-
+        if exists(cond):
+            assert self.accept_condition

-
+        prenormed = self.norm(x)

-
-
-
-        self.fieldnames = set(fields.keys())
+        if has_input_cond:
+            if cond.ndim == 2:
+                cond = rearrange(cond, 'b d -> b 1 d')

-
+            cond_scale = self.to_gamma(cond)

-
+            conditioned = prenormed * cond_scale

-
+            # handle a condition mask

-
-
-
-
-
-            # memmap file
-
-            filepath = folder / f'{field_name}.data.npy'
-            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
-
-            self.memmaps[field_name] = memmap
-            self.shapes[field_name] = shape
-            self.dtypes[field_name] = dtype
-
-        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
-
-    def __len__(self):
-        return (self.episode_lens > 0).sum().item()
-
-    def reset_(self):
-        self.episode_lens[:] = 0
-        self.episode_index = 0
-        self.timestep_index = 0
-
-    def advance_episode(self):
-        self.episode_index = (self.episode_index + 1) % self.max_episodes
-        self.timestep_index = 0
-
-    def flush(self):
-        self.episode_lens[self.episode_index] = self.timestep_index
-
-        for memmap in self.memmaps.values():
-            memmap.flush()
-
-        self.episode_lens.flush()
-
-    @contextmanager
-    def one_episode(self):
-
-        yield
-
-        self.flush()
-        self.advance_episode()
-
-    @beartype
-    def store_datapoint(
-        self,
-        episode_index: int,
-        timestep_index: int,
-        name: str,
-        datapoint: Tensor | ndarray
-    ):
-        assert 0 <= episode_index < self.max_episodes
-        assert 0 <= timestep_index < self.max_timesteps
-
-        if is_tensor(datapoint):
-            datapoint = datapoint.detach().cpu().numpy()
-
-        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
-
-        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
-
-        self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
-
-    def store(
-        self,
-        **data
-    ):
-        assert is_bearable(data, dict[str, Tensor | ndarray])
-
-        assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+        if exists(cond_mask):
+            prenormed = einx.where('b n, b n d, b n d', cond_mask, conditioned, prenormed)
+        else:
+            prenormed = conditioned

-
+        # the main block, either attention or feedforward or whatever

-
+        all_fn_out = self.fn(prenormed, *args, **kwargs)

-
+        if not has_input_cond:
+            return all_fn_out

-        return
+        # function may return multiple args

-
-        self,
-        episode_mapping: Tensor | list[list[int]] | None = None,
-    ) -> Dataset:
-        self.flush()
+        (out, *rest), tree_spec = tree_flatten(all_fn_out)

-
+        scale_out = self.to_ada_norm_zero(cond).sigmoid()

-        if
-
+        if exists(cond_mask):
+            is_cond = rearrange(cond_mask, '... -> ... 1')
+            out = torch.where(is_cond, out * scale_out, out)
+        else:
+            out = out * scale_out

-
+        # restore

-
-        self,
-        batch_size,
-        episode_mapping: Tensor | list[list[int]] | None = None,
-        **kwargs
-    ) -> DataLoader:
-        self.flush()
+        all_fn_out = tree_unflatten((out, *rest), tree_spec)

-        return
+        return all_fn_out

 # transformer-xl with ppo

@@ -472,15 +503,13 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
-        pre_rmsnorm = True,
         fixed_window_size = False,
-        accept_value_residual = False
+        accept_value_residual = False,
+        max_mem_segments = 1
     ):
         super().__init__()
         self.scale = dim_head ** -0.5

-        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
-
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -512,20 +541,22 @@ class Attention(Module):

         self.fixed_window_size = fixed_window_size
         self.window_size = window_size
+        self.max_mem_segments = max_mem_segments
+
+        self.register_buffer('causal_mask', None, persistent = False)

     def forward(
         self,
         tokens,
         value_residual = None,
         kv_cache = None,
+        past_segments = None,
         return_kv_cache = False,
     ):
         seq_len = tokens.shape[-2]

         device = tokens.device

-        tokens = self.norm(tokens)
-
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

         q, k, v = map(self.split_heads, (q, k, v))
@@ -544,6 +575,11 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)

+        if exists(past_segments):
+            pk, pv = past_segments
+            k = cat((pk, k), dim = -2)
+            v = cat((pv, v), dim = -2)
+
         if return_kv_cache:
             next_kv_cache = stack((k, v))

@@ -557,9 +593,12 @@ class Attention(Module):
             i_seq = arange(i, device = device)
             j_seq = arange(j, device = device) - (j - i)
             dist = einx.subtract('i, j -> i j', i_seq, j_seq)
-            causal_mask = (dist < 0) | (dist > self.window_size)
+            causal_mask = (dist < 0) | (dist > (self.max_mem_segments * self.window_size))
         else:
-
+            if not exists(self.causal_mask) or self.causal_mask.shape != (i, j):
+                self.causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+
+            causal_mask = self.causal_mask

         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

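
The mask change above widens the sliding window: with up to `max_mem_segments` cached windows prepended to the keys and values, a query can now attend back `max_mem_segments * window_size` tokens instead of a single window. A small standalone sketch of the same mask construction, with illustrative sizes:

# sketch of the widened sliding causal mask used above
import torch
import einx

i, j = 4, 8                        # 4 new queries, 8 total keys (4 cached + 4 new)
window_size, max_mem_segments = 4, 1

i_seq = torch.arange(i)
j_seq = torch.arange(j) - (j - i)
dist = einx.subtract('i, j -> i j', i_seq, j_seq)

# True entries are masked out: future tokens, or tokens further back than the allowed window(s)
causal_mask = (dist < 0) | (dist > (max_mem_segments * window_size))
print(causal_mask)
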
@@ -606,6 +645,7 @@ class FeedForward(Module):
         return self.proj_out(x)

 class TransformerXL(Module):
+    @beartype
     def __init__(
         self,
         dim,
@@ -614,27 +654,83 @@ class TransformerXL(Module):
         dim_head = 64,
         heads = 8,
         expansion_factor = 4.,
+        dim_cond = None,
         final_norm = True,
         fixed_window_size = False,
+        gru_layers = False,
+        long_term_mem_layers: tuple[int, ...] = (),
+        mem_kwargs: dict = dict(),
+        num_residual_streams = 1,
+        max_mem_segments = 1
     ):
         super().__init__()
+        self.dim = dim
+
+        # memory
+
+        long_term_mem_layers = set(long_term_mem_layers)
+
+        assert all([1 <= l <= depth for l in long_term_mem_layers])
+
+        self.long_term_mem_layers = long_term_mem_layers
+        self.num_mem_mlps = len(long_term_mem_layers)
+        self.has_mem = self.num_mem_mlps > 0
+        self.max_mem_segments = max_mem_segments
+
+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = mc_get_init_and_expand_reduce_stream_functions(num_residual_streams)
+
+        # condition
+
+        condition = exists(dim_cond)
+
+        self.to_cond_tokens = MLP(dim_cond, dim * 2, activate_last = True) if exists(dim_cond) else None
+
+        norm_fn = partial(MaybeAdaRMSNormWrapper, dim = dim, dim_cond = (dim * 2) if condition else None)
+
+        # layers

         layers = ModuleList([])

         for i in range(depth):
-
+            layer = i + 1
+            is_first = layer == 1
+            has_mem = layer in long_term_mem_layers
+
+            gru = norm_fn(nn.GRU(dim, dim, batch_first = True)) if gru_layers else None
+
+            mem = MemoryMLP(dim, **mem_kwargs) if has_mem else None
+
+            attn = norm_fn(Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first, max_mem_segments = max_mem_segments))
+
+            ff = norm_fn(FeedForward(dim = dim, expansion_factor = expansion_factor))

-
+            # maybe hyper connections

-
+            mem_store_hc = init_hyper_conn(dim = dim, add_branch_out_to_residual = False)
+
+            if num_residual_streams > 1:
+
+                attn = init_hyper_conn(dim = dim, branch = attn)
+
+                ff = init_hyper_conn(dim = dim, branch = ff)
+
+                if gru_layers:
+                    gru = init_hyper_conn(dim = dim, branch = gru)
+
+                if has_mem:
+                    mem = init_hyper_conn(dim = dim, branch = mem, forward_method_names = ('store',))

             layers.append(ModuleList([
-                attn, ff
+                gru, mem, mem_store_hc, attn, ff
             ]))

         self.layers = layers
         self.norm = RMSNorm(dim) if final_norm else Identity()

+        self.gru_layers = gru_layers
+
         # fixed window size

         self.fixed_window_size = fixed_window_size
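
For orientation, a hypothetical construction of the new `TransformerXL`. The keyword names are taken from the constructor above; the values are illustrative, and the leading `dim`, `depth` and `window_size` arguments are assumed from their use in the body (they sit outside this hunk):

# hypothetical usage sketch - argument values are illustrative only
import torch
from locoformer.locoformer import TransformerXL  # the module changed in this diff

transformer = TransformerXL(
    dim = 256,
    depth = 6,
    window_size = 32,
    dim_cond = 16,                  # enables the AdaRMSNorm conditioning path
    gru_layers = True,              # a GRU block before each attention block
    long_term_mem_layers = (3, 6),  # 1-indexed layers that carry a MemoryMLP
    num_residual_streams = 4,       # > 1 switches on hyper connections
    max_mem_segments = 2            # attend over up to 2 past windows
)

# forward returns the final embedding plus a TransformerMemory cache to carry across windows
embed, cache = transformer(torch.randn(1, 32, 256), condition = torch.randn(1, 16))
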
@@ -643,72 +739,439 @@ class TransformerXL(Module):
     def forward(
         self,
         x,
-        cache = None,
-        return_kv_cache = False
+        cache: TransformerMemory | None = None,
+        return_kv_cache = False,
+        condition: Tensor | None = None,
+        cond_mask: Tensor | None = None
     ):
+        curr_token_seq_len = x.shape[-2]

-
+        # cache and residuals
+
+        num_layers = len(self.layers)
+
+        # extract variables from cache
+
+        is_first_window = True
+        total_tokens = 0
+        kv_cache = gru_cache = mem_mlp_cache = mem_mlp_hidden_states = memory_segments = None
+
+        if exists(cache):
+            total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
+            is_first_window = total_tokens < self.window_size
+        else:
+            memory_segments = deque(maxlen = self.max_mem_segments)
+
+        # handle memory segments
+
+        past_segments = None
+        if len(memory_segments) > 0:
+            past_segments = stack(list(memory_segments), dim = 0)
+            past_segments = rearrange(past_segments, 'l depth kv b h n d -> depth kv b h (l n) d')
+
+        kv_cache = default(kv_cache, (None,) * num_layers)
+        gru_cache = default(gru_cache, (None,) * num_layers)
+        mem_mlp_cache = default(mem_mlp_cache, (None,) * num_layers)
+        mem_mlp_hidden_states = default(mem_mlp_hidden_states, (None,) * num_layers)
+
+        # prepare next cache
+
+        next_kv_caches = []
+        next_gru_hiddens = [] if self.gru_layers else None
+        next_mem_mlp_cache = [] if self.has_mem else None
+        next_mem_mlp_hidden_states = [] if self.has_mem else None
+        next_total_tokens = total_tokens + curr_token_seq_len
+
+        is_window_boundary = divisible_by(next_total_tokens, self.window_size)

-        next_kv_caches = []
         value_residual = None

-
+        # handle condition
+
+        cond_tokens = None
+
+        if exists(condition):
+            assert exists(self.to_cond_tokens)
+            cond_tokens = self.to_cond_tokens(condition)
+
+        cond_kwargs = dict(cond = cond_tokens, cond_mask = cond_mask)
+
+        # hc expand
+
+        x = self.expand_streams(x)
+
+        # layers
+
+        for layer_index, ((maybe_gru, maybe_mem, maybe_mem_store_hc, attn, ff), layer_gru_cache, layer_mem_mlp, layer_kv_cache, layer_hidden_states) in enumerate(zip(self.layers, gru_cache, mem_mlp_cache, kv_cache, mem_mlp_hidden_states)):
+
+            # handle maybe rnn

-
+            if exists(maybe_gru):
+                x, gru_hiddens = maybe_gru(x, layer_gru_cache, **cond_kwargs)

-
-
+                next_gru_hiddens.append(gru_hiddens)
+
+            # maybe handle retrieving
+
+            is_mem_layer = exists(maybe_mem)
+
+            if (
+                not is_first_window and
+                is_mem_layer
+            ):
+                x = maybe_mem(x, layer_mem_mlp)
+
+            # attention
+
+            layer_past_segments = None
+            if exists(past_segments):
+                layer_past_segments = past_segments[layer_index]
+
+            x, (next_kv_cache, values) = attn(x, **cond_kwargs, value_residual = value_residual, kv_cache = layer_kv_cache, past_segments = layer_past_segments, return_kv_cache = True)
+
+            # handle storing of memory
+
+            if self.has_mem:
+                next_mem_mlp = layer_mem_mlp
+                next_layer_hidden_states = layer_hidden_states
+
+                if is_mem_layer:
+                    # accumulate hidden states
+                    next_layer_hidden_states = safe_cat(layer_hidden_states, x, dim = -2)
+
+                    if is_window_boundary:
+                        mem_store_input, _ = maybe_mem_store_hc(next_layer_hidden_states)
+
+                        next_mem_mlp = maybe_mem.store(mem_store_input, layer_mem_mlp)
+                        next_layer_hidden_states = None
+
+                next_mem_mlp_cache.append(next_mem_mlp)
+                next_mem_mlp_hidden_states.append(next_layer_hidden_states)
+
+            # feedforward
+
+            x = ff(x, **cond_kwargs)

             next_kv_caches.append(next_kv_cache)
             value_residual = default(value_residual, values)

-
+        # hc reduce

-
-
+        x = self.reduce_streams(x)
+
+        # norm
+
+        embed = self.norm(x)

         next_kv_cache = stack(next_kv_caches)

-
+        if exists(next_gru_hiddens):
+            next_gru_hiddens = stack(next_gru_hiddens)
+
+        next_cache = TransformerMemory(next_total_tokens, next_kv_cache, next_gru_hiddens, next_mem_mlp_cache, next_mem_mlp_hidden_states, memory_segments)
+
+        return embed, next_cache
+
+# simple 2 layer memory mlp
+# following ttt/titans
+
+from torch.func import functional_call, grad, vmap
+
+class MemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        expansion_factor = 4.
+    ):
+        super().__init__()
+
+        dim_hidden = int(dim * expansion_factor)
+
+        self.norm = nn.RMSNorm(dim)
+
+        # queries, keys, values
+
+        self.to_queries = Linear(dim, dim, bias = False)
+
+        self.to_key_values = nn.Sequential(
+            Linear(dim, dim * 2, bias = False),
+            nn.SiLU()
+        )
+
+        # memory mlp
+
+        self.mlp = MLP(dim, dim_hidden, dim, activation = nn.SiLU())
+
+        # initial params
+
+        self.init_mlp_params = dict(self.mlp.named_parameters())
+
+        # grad for storing
+
+        def retrieve_fn(params, queries: Tensor):
+            return functional_call(self.mlp, params, queries)
+
+        def loss_fn(params, inputs: tuple[Tensor, Tensor, Tensor]):
+            keys, values, learning_rate = inputs
+            pred = functional_call(self.mlp, params, keys)
+            loss = F.mse_loss(pred, values, reduction = 'none')
+            loss = loss * learning_rate
+            return loss.mean()
+
+        self.grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0, 0)))
+
+        self.retrieve_fn = vmap(retrieve_fn, in_dims = (0, 0))
+
+        # forgetting
+
+        self.to_forget_gate = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, 1, bias = False),
+            Rearrange('b 1 -> b'),
+            nn.Sigmoid()
+        )
+
+        # loss weight / learning rate
+
+        self.to_loss_weight = nn.Linear(dim, 1, bias = False)
+
+    def get_init_mlp_params(
+        self,
+        batch_size
+    ):
+        return {name: repeat(params, '... -> b ...', b = batch_size) for name, params in self.init_mlp_params.items()}
+
+    def store(
+        self,
+        tokens, # (b n d)
+        memories: dict[str, Tensor] | None = None
+    ):
+
+        batch_size = tokens.shape[0]
+
+        if not exists(memories):
+            memories = self.get_init_mlp_params(batch_size)
+
+        tokens = self.norm(tokens)
+
+        keys, values = self.to_key_values(tokens).chunk(2, dim = -1)
+
+        loss_weight = self.to_loss_weight(tokens)
+
+        grad = self.grad_fn(memories, (keys, values, loss_weight))
+
+        # prepare forget
+
+        forget = self.to_forget_gate(tokens)
+
+        # update memories
+
+        next_memories = dict()
+
+        for param_name, past_memory in memories.items():
+            change = grad[param_name]
+
+            past_memory = einx.multiply('b, b ...', forget, past_memory)
+
+            next_memories[param_name] = past_memory - change
+
+        return next_memories
+
+    def forward(
+        self,
+        tokens, # (b n d)
+        memories: dict[str, Tensor] | None = None
+    ):
+        batch_size = tokens.shape[0]
+
+        if not exists(memories):
+            memories = self.get_init_mlp_params(batch_size)
+
+        tokens = self.norm(tokens)
+
+        queries = self.to_queries(tokens)
+
+        retrieved = self.retrieve_fn(memories, queries)
+
+        return retrieved
+
+# state embedder
+
+class StateEmbedder(Module):
+    @beartype
+    def __init__(
+        self,
+        dim,
+        dim_state: tuple[int, ...] | list[int] | int,
+        num_internal_states: int | None = None,
+        internal_states_selectors: list[list[int]] | None = None
+    ):
+        super().__init__()
+        dim_hidden = dim * 2
+
+        self.image_to_token = nn.Sequential(
+            Rearrange('b t c h w -> b c t h w'),
+            nn.Conv3d(3, dim_hidden, (1, 7, 7), padding = (0, 3, 3)),
+            nn.ReLU(),
+            nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
+            nn.ReLU(),
+            nn.Conv3d(dim_hidden, dim_hidden, (1, 3, 3), stride = (1, 2, 2), padding = (0, 1, 1)),
+            Reduce('b c t h w -> b t c', 'mean'),
+            nn.Linear(dim_hidden, dim)
+        )
+
+        dim_states = (dim_state,) if not isinstance(dim_state, (tuple, list)) else dim_state
+
+        self.dim_states = dim_states
+        self.state_to_token = ModuleList([MLP(dim_state, dim, bias = False) for dim_state in dim_states])
+
+        # internal state embeds for each robot
+
+        self.internal_state_embedder = None
+
+        if exists(num_internal_states) and exists(internal_states_selectors):
+            self.internal_state_embedder = Embed(
+                dim,
+                num_continuous = num_internal_states,
+                selectors = internal_states_selectors
+            )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        state,
+        state_type,
+        state_id = 0,
+        internal_state = None,
+        internal_state_selector_id: int | None = None
+    ):
+
+        if state_type == 'image':
+            token_embeds = self.image_to_token(state)
+        elif state_type == 'raw':
+            state_to_token = self.state_to_token[state_id]
+            token_embeds = state_to_token(state)
+        else:
+            raise ValueError('invalid state type')

-
+        if (
+            exists(internal_state_selector_id) and
+            exists(internal_state) and
+            exists(self.internal_state_embedder)
+        ):
+            internal_state = internal_state.to(self.device)
+
+            internal_state_embed = self.internal_state_embedder(internal_state, selector_index = internal_state_selector_id)
+
+            token_embeds = token_embeds + internal_state_embed
+
+        return token_embeds

 # class

+OneRewardShaper = Callable[..., float | Tensor]
+
+MaybeOneRewardShaper = OneRewardShaper | None
+
+@beartype
+def default_parse_env_reset_out(reset_out: tuple):
+    assert len(reset_out) == 2
+    return dict(zip(('state', 'info'), reset_out))
+
+@beartype
+def default_parse_env_step_out(step_out: tuple):
+    assert len(step_out) in {4, 5}
+
+    if len(step_out) == 5:
+        data_dict = dict(zip(('state', 'reward', 'terminated', 'truncated', 'info'), step_out))
+    elif len(step_out) == 4:
+        data_dict = dict(zip(('state', 'reward', 'terminated', 'info'), step_out))
+        data_dict['truncated'] = False
+
+    return data_dict
+
 class Locoformer(Module):
     def __init__(
         self,
-        embedder: Module,
-        unembedder:
+        embedder: dict | Module,
+        unembedder: dict | Readout,
         transformer: dict | TransformerXL,
+        *,
         discount_factor = 0.999,
         gae_lam = 0.95,
         ppo_eps_clip = 0.2,
         ppo_entropy_weight = 0.01,
         ppo_value_clip = 0.4,
-
+        ppo_soft_constrain_action_max = None,
+        ppo_soft_constrain_action_loss_weight = 0.1,
+        dim_value_input = None, # needs to be set for value network to be available
         value_network: Module = nn.Identity(),
+        policy_network: Module = nn.Identity(),
+        state_pred_network: Module | None = None,
+        embed_past_action = False,
+        state_pred_loss_weight = 0.05,
         reward_range: tuple[float, float] | None = None,
-        reward_shaping_fns:
+        reward_shaping_fns: (
+            MaybeOneRewardShaper |
+            list[MaybeOneRewardShaper] |
+            list[list[MaybeOneRewardShaper]]
+        ) = None,
         num_reward_bins = 32,
         hl_gauss_loss_kwargs = dict(),
         value_loss_weight = 0.5,
         calc_gae_kwargs: dict = dict(),
-
-
+        parse_env_reset_out: Callable | None = None,
+        parse_env_step_out: Callable | None = None,
+        recurrent_cache = True,
+        use_spo = False, # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
+        asymmetric_spo = False, # https://openreview.net/pdf?id=BA6n0nmagi
+        max_mem_segments = 1
     ):
         super().__init__()

         if isinstance(transformer, dict):
-            transformer = TransformerXL(**transformer)
+            transformer = TransformerXL(max_mem_segments = max_mem_segments, **transformer)

         self.transformer = transformer

+        # handle state embedder
+
+        if isinstance(embedder, dict):
+            embedder = StateEmbedder(**embedder)
+
         self.embedder = embedder
+
+        # unembed state to actions or ssl predictions
+
+        action_embedder = None
+        if isinstance(unembedder, dict):
+            action_embedder, unembedder = EmbedAndReadout(
+                explicit_single_action_dim_given = True,
+                **unembedder,
+            )
+
         self.unembedder = unembedder

+        # embedding past actions
+
+        self.past_action_embedder = None
+        self.embed_past_action = embed_past_action
+
+        if embed_past_action and exists(action_embedder):
+            self.past_action_embedder = action_embedder
+
+        # attention window related
+
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
+        self.max_mem_segments = max_mem_segments
+
+        # policy network
+
+        self.policy_network = policy_network

         # determine value network, using HL Gauss Layer

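
The `MemoryMLP` above stores by differentiating a keyed reconstruction loss through `functional_call` and `vmap`, so every batch element carries and updates its own copy of the MLP weights (fast weights, in the TTT/Titans style). A minimal standalone sketch of just that torch.func mechanic, independent of the class; the toy MLP and shapes are illustrative:

# sketch of the per-batch fast-weight update pattern MemoryMLP relies on
import torch
from torch import nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap
from einops import repeat

mlp = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))

def loss_fn(params, inputs):
    keys, values = inputs
    pred = functional_call(mlp, params, keys)  # run the mlp with explicit weights
    return F.mse_loss(pred, values)

# per-sample gradient of the store loss w.r.t. the fast weights
grad_fn = vmap(grad(loss_fn), in_dims = (0, (0, 0)))

batch = 3
# one copy of the mlp weights per batch element
params = {name: repeat(p, '... -> b ...', b = batch) for name, p in mlp.named_parameters()}

keys = torch.randn(batch, 5, 8)
values = torch.randn(batch, 5, 8)

grads = grad_fn(params, (keys, values))

# one gradient-descent-like "store" step per batch element
params = {name: p - grads[name] for name, p in params.items()}

# retrieval: run the mlp with each element's own fast weights
retrieve = vmap(lambda p, q: functional_call(mlp, p, q), in_dims = (0, 0))
out = retrieve(params, torch.randn(batch, 5, 8))
print(out.shape)  # torch.Size([3, 5, 8])
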
@@ -731,6 +1194,22 @@ class Locoformer(Module):
             **hl_gauss_loss_kwargs
         )

+        # state prediction related
+
+        self.can_pred_state = exists(state_pred_network)
+        self.state_pred_network = state_pred_network
+
+        if exists(state_pred_network):
+            dim_states = self.embedder.dim_states
+            total_dim_states = sum(dim_states)
+
+            selectors = [t.tolist() for t in arange(total_dim_states).split(dim_states)]
+
+            self.state_pred_head = Readout(transformer.dim, num_continuous = total_dim_states, selectors = selectors)
+
+        self.has_state_pred_loss = state_pred_loss_weight > 0.
+        self.state_pred_loss_weight = state_pred_loss_weight
+
         # ppo related

         self.discount_factor = discount_factor
@@ -738,6 +1217,9 @@ class Locoformer(Module):
         self.ppo_eps_clip = ppo_eps_clip
         self.ppo_entropy_weight = ppo_entropy_weight
         self.ppo_value_clip = ppo_value_clip
+        self.ppo_soft_constrain_action_max = ppo_soft_constrain_action_max
+        self.ppo_soft_constrain_action_loss_weight = ppo_soft_constrain_action_loss_weight
+
         self.value_loss_weight = value_loss_weight

         self.calc_gae_kwargs = calc_gae_kwargs
@@ -746,14 +1228,26 @@ class Locoformer(Module):

         self.use_spo = use_spo

-
+        self.asymmetric_spo = asymmetric_spo
+
+        # maybe recurrent kv cache, from Ding et al. https://arxiv.org/abs/2012.15688
+
+        self.recurrent_cache = recurrent_cache

-
+        # environment returns to dictionary
+
+        self.parse_env_reset_out = default(parse_env_reset_out, default_parse_env_reset_out)
+        self.parse_env_step_out = default(parse_env_step_out, default_parse_env_step_out)

         # reward shaping function

         self.has_reward_shaping = exists(reward_shaping_fns)
+
+        if is_bearable(reward_shaping_fns, OneRewardShaper):
+            reward_shaping_fns = [reward_shaping_fns]
+
         self.reward_shaping_fns = reward_shaping_fns
+        self.reward_shaping_fns_multiple_envs = is_bearable(reward_shaping_fns, list[list[OneRewardShaper]])

         # loss related

@@ -764,7 +1258,10 @@ class Locoformer(Module):
         return next(self.parameters()).device

     def actor_parameters(self):
-        return
+        return [
+            *self.policy_network.parameters(),
+            *self.unembedder.parameters()
+        ]

     def critic_parameters(self):
         if not exists(self.to_value_pred):
@@ -772,41 +1269,140 @@ class Locoformer(Module):

         return self.to_value_pred.parameters()

+    @beartype
+    def learn(
+        self,
+        optims,
+        accelerator,
+        replay,
+        state_embed_kwargs: dict,
+        action_select_kwargs: dict,
+        state_id_kwarg: dict = dict(),
+        batch_size = 16,
+        epochs = 2,
+        use_vision = False,
+        compute_state_pred_loss = False,
+        state_pred_loss_weight = None,
+        maybe_construct_trial_from_buffer: Callable[[ReplayBuffer], Tensor] | None = None
+    ):
+        state_field = 'state_image' if use_vision else 'state'
+
+        episode_mapping = None
+
+        if exists(maybe_construct_trial_from_buffer):
+            episode_mapping = maybe_construct_trial_from_buffer(replay)
+
+        dataset = replay.dataset()
+
+        if exists(episode_mapping):
+            dataset = RemappedReplayDataset(dataset, episode_mapping)
+
+        dl = replay.dataloader(
+            batch_size = batch_size,
+            dataset = dataset,
+            shuffle = True
+        )
+
+        self, dl, *optims = accelerator.prepare(self, dl, *optims)
+
+        for _ in range(epochs):
+            for data in dl:
+
+                data = SimpleNamespace(**data)
+
+                actor_loss, critic_loss = self.ppo(
+                    state = getattr(data, state_field),
+                    internal_state = getattr(data, 'internal_state', None),
+                    action = data.action,
+                    action_log_prob = data.action_log_prob,
+                    reward = data.reward,
+                    value = data.value,
+                    done = data.done,
+                    condition = getattr(data, 'condition', None),
+                    cond_mask = getattr(data, 'cond_mask', None),
+                    episode_lens = data._lens,
+                    optims = optims,
+                    state_embed_kwargs = state_embed_kwargs,
+                    action_select_kwargs = action_select_kwargs,
+                    state_id_kwarg = state_id_kwarg,
+                    compute_state_pred_loss = compute_state_pred_loss,
+                    state_pred_loss_weight = state_pred_loss_weight,
+                    accelerator = accelerator
+                )
+
+                accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
+    def evolve(
+        self,
+        environment,
+        **kwargs
+    ):
+        evo_strat = EvoStrategy(self, environment = environment, **kwargs)
+        evo_strat()
+
     def ppo(
         self,
         state,
+        internal_state,
         action,
-
+        action_log_prob,
         reward,
-
-
+        value,
+        done,
         episode_lens,
-
-
+        condition: Tensor | None = None,
+        cond_mask: Tensor | None = None,
+        optims: list[Optimizer] | None = None,
+        state_embed_kwargs: dict = dict(),
+        action_select_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
+        compute_state_pred_loss = True,
+        state_pred_loss_weight = None,
+        accelerator = None,
+        max_grad_norm = 0.5
     ):
-
-        total_learnable_tokens = mask.sum().item()
+        state_pred_loss_weight = default(state_pred_loss_weight, self.state_pred_loss_weight)

+        window_size = self.window_size
+        mask = ~done
         seq_len = state.shape[1]
-
+        padding_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+        gae_mask = padding_mask & mask

-
+        total_learnable_tokens = gae_mask.sum().item()

-        advantage =
+        advantage, returns = calc_gae(reward, value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)

-
-
-
-
-
-
-
-
-
-
-
-
-
+        advantage = normalize(advantage, mask = gae_mask)
+
+        advantage = rearrange(advantage, '... -> ... 1')
+
+        past_action = pad_at_dim(action, (1, -1), dim = -2)
+
+        data_dict = dict(
+            state = state,
+            internal_state = internal_state,
+            action = action,
+            past_action = past_action,
+            old_action_log_prob = action_log_prob,
+            reward = reward,
+            mask = mask,
+            advantage = advantage,
+            returns = returns,
+            windowed_gae_mask = gae_mask,
+            condition = condition,
+            cond_mask = cond_mask
+        )
+
+        num_windows = math.ceil(seq_len / window_size)
+
+        windowed_data = dict()
+
+        for name, tensor in data_dict.items():
+            if exists(tensor):
+                windowed_data[name] = tensor.split(window_size, dim = 1)
+            else:
+                windowed_data[name] = (None,) * num_windows

         mean_actor_loss = self.zero.clone()
         mean_critic_loss = self.zero.clone()
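
The `ppo` method above runs a truncated-BPTT-style update: every episode tensor is split into `window_size` chunks along time, and the transformer cache is carried (detached) from one chunk to the next. A quick sketch of just the chunking, with illustrative shapes:

# sketch: splitting episode tensors into attention-window chunks along time
import math
import torch

window_size = 4
seq_len = 10
state = torch.randn(2, seq_len, 8)   # (batch, time, dim)
reward = torch.randn(2, seq_len)

num_windows = math.ceil(seq_len / window_size)  # 3 windows: lengths 4, 4, 2

windows = {
    'state': state.split(window_size, dim = 1),
    'reward': reward.split(window_size, dim = 1),
}

for chunk in zip(*windows.values()):
    data = dict(zip(windows.keys(), chunk))
    print(data['state'].shape, data['reward'].shape)
# (2, 4, 8) (2, 4) / (2, 4, 8) (2, 4) / (2, 2, 8) (2, 2)
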
@@ -815,52 +1411,90 @@ class Locoformer(Module):
|
|
|
815
1411
|
|
|
816
1412
|
cache = None
|
|
817
1413
|
|
|
818
|
-
for (
|
|
819
|
-
state,
|
|
820
|
-
action,
|
|
821
|
-
old_action_log_prob,
|
|
822
|
-
reward,
|
|
823
|
-
old_value,
|
|
824
|
-
mask,
|
|
825
|
-
advantage,
|
|
826
|
-
returns
|
|
827
|
-
) in zip(*windowed_tensors):
|
|
1414
|
+
for window_tensors in zip(*windowed_data.values()):
|
|
828
1415
|
|
|
829
|
-
|
|
830
|
-
entropy = calc_entropy(action_logits)
|
|
1416
|
+
data = SimpleNamespace(**dict(zip(windowed_data.keys(), window_tensors)))
|
|
831
1417
|
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
log_prob =
|
|
1418
|
+
((action_logits, maybe_state_pred), value_logits), cache = self.forward(data.state, past_action = data.past_action if self.embed_past_action else None, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': data.internal_state}, action_select_kwargs = action_select_kwargs, state_id_kwarg = state_id_kwarg, condition = data.condition, cond_mask = data.cond_mask, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True, return_state_pred = True)
|
|
1419
|
+
|
|
1420
|
+
log_prob = self.unembedder.log_prob(action_logits, data.action, **action_select_kwargs)
|
|
835
1421
|
|
|
836
1422
|
# update actor, classic clipped surrogate loss
|
|
837
1423
|
|
|
838
1424
|
eps_clip = self.ppo_eps_clip
|
|
839
|
-
ratio = (log_prob - old_action_log_prob).exp()
|
|
1425
|
+
ratio = (log_prob - data.old_action_log_prob).exp()
|
|
1426
|
+
|
|
1427
|
+
calc_spo = lambda: -(ratio * data.advantage - (data.advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
|
|
1428
|
+
|
|
1429
|
+
calc_ppo = lambda: -torch.min(ratio * data.advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * data.advantage)
|
|
840
1430
|
|
|
841
|
-
if self.
|
|
842
|
-
actor_loss =
|
|
1431
|
+
if self.asymmetric_spo:
|
|
1432
|
+
actor_loss = torch.where(data.advantage >= 0, calc_ppo(), calc_spo())
|
|
1433
|
+
elif self.use_spo:
|
|
1434
|
+
actor_loss = calc_spo()
|
|
843
1435
|
else:
|
|
844
|
-
actor_loss =
|
|
1436
|
+
actor_loss = calc_ppo()
|
|
845
1437
|
|
|
846
|
-
|
|
1438
|
+
# maybe entropy
|
|
847
1439
|
|
|
848
|
-
|
|
849
|
-
|
|
1440
|
+
if self.ppo_entropy_weight > 0.:
|
|
1441
|
+
entropy = self.unembedder.entropy(action_logits, **action_select_kwargs)
|
|
850
1442
|
|
|
851
|
-
|
|
1443
|
+
if exists(entropy):
|
|
1444
|
+
actor_loss = actor_loss - self.ppo_entropy_weight * entropy
|
|
1445
|
+
|
|
1446
|
+
windowed_actor_loss = actor_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
|
|
+
+            # maybe add state prediction
+
+            if (
+                exists(maybe_state_pred) and
+                self.has_state_pred_loss and
+                compute_state_pred_loss and
+                data.windowed_gae_mask[:, :-1].any()
+            ):
+                state_pred = maybe_state_pred[:, :-1]
+                state_labels = data.state[:, 1:]
+                loss_mask = data.windowed_gae_mask[:, :-1]
+
+                state_id = state_id_kwarg.get('state_id', 0)
+
+                state_pred_loss = self.state_pred_head.calculate_loss(state_pred, state_labels, selector_index = state_id, return_unreduced_loss = True)
+
+                state_pred_loss = state_pred_loss.mean(dim = -1) # average over state features
+
+                windowed_state_pred_loss = state_pred_loss[loss_mask].sum() / total_learnable_tokens
+
+                windowed_actor_loss = (
+                    windowed_actor_loss +
+                    windowed_state_pred_loss * state_pred_loss_weight
+                )
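The auxiliary state prediction loss above uses a shift-by-one alignment: the prediction at timestep t is scored against the observed state at t + 1, which is why both the last prediction and the first label are dropped. A sketch with a plain MSE standing in for `state_pred_head.calculate_loss`:

import torch
import torch.nn.functional as F

states = torch.randn(1, 8, 48)      # (batch, time, state features)
state_pred = torch.randn(1, 8, 48)  # one prediction per input timestep

pred = state_pred[:, :-1]           # predictions for t = 0 .. 6
labels = states[:, 1:]              # observed states for t = 1 .. 7

loss = F.mse_loss(pred, labels, reduction = 'none').mean(dim = -1)  # average over state features
print(loss.shape)  # (1, 7), one loss per transition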
+
+            # maybe soft constrain continuous actions
+
+            if (
+                self.ppo_soft_constrain_action_max and
+                self.unembedder.has_continuous
+            ):
+                loss_mask = data.windowed_gae_mask

-
+                soft_constrain_loss = (action_logits[..., 0].abs() - self.ppo_soft_constrain_action_max).relu().pow(2)
+                windowed_soft_constrain_loss = soft_constrain_loss[loss_mask].sum() / total_learnable_tokens

-
-
+                windowed_actor_loss = (
+                    windowed_actor_loss +
+                    windowed_soft_constrain_loss * self.ppo_soft_constrain_action_loss_weight
+                )
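The soft action constraint as a standalone function: only the part of the continuous action mean (`action_logits[..., 0]`) that exceeds the allowed magnitude is penalized, quadratically, so in-range actions receive no gradient:

import torch

def soft_constrain_penalty(action_mean, action_max):
    # zero inside [-action_max, action_max], grows as (|a| - action_max)^2 outside
    return (action_mean.abs() - action_max).relu().pow(2)

mu = torch.tensor([-3.0, -1.0, 0.5, 2.0])
print(soft_constrain_penalty(mu, 1.5))  # tensor([2.2500, 0.0000, 0.0000, 0.2500])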

-
-            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+            # windowed loss

-
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic

-
+            value_loss = self.hl_gauss_loss(value_logits, data.returns, reduction = 'none') * self.value_loss_weight
+
+            windowed_critic_loss = value_loss[data.windowed_gae_mask].sum() / total_learnable_tokens
             windowed_critic_loss.backward(retain_graph = True)

             # accumulate
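For the critic, `self.hl_gauss_loss` comes from the hl_gauss_pytorch dependency; the sketch below reimplements the underlying idea from scratch (not the library's API): the scalar return is smeared into a histogram over fixed value bins with a Gaussian, and the critic's `value_logits` are trained with cross entropy against that soft target:

import torch
import torch.nn.functional as F

min_value, max_value, num_bins, sigma = -10., 10., 32, 0.75

support = torch.linspace(min_value, max_value, num_bins + 1)  # bin edges

def hl_gauss_target(returns):
    # probability mass each bin receives under N(returns, sigma^2)
    cdf = torch.distributions.Normal(returns[..., None], sigma).cdf(support)
    probs = cdf[..., 1:] - cdf[..., :-1]
    return probs / probs.sum(dim = -1, keepdim = True)

value_logits = torch.randn(4, num_bins)
returns = torch.randn(4) * 3.

loss = F.cross_entropy(value_logits, hl_gauss_target(returns))
print(loss)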
@@ -870,31 +1504,69 @@ class Locoformer(Module):

         # optimizer update

-        if exists(
-
-
+        if exists(optims):
+
+            if exists(accelerator):
+                accelerator.clip_grad_norm_(self.parameters(), max_grad_norm)
+            else:
+                nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)

-
-
-
+            for optim in optims:
+                optim.step()
+                optim.zero_grad()

         # return losses for logging

         return mean_actor_loss.detach(), mean_critic_loss.detach()

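The update pattern above, reduced: gradients accumulate across windows through repeated `.backward()` calls (the diff retains the graph because consecutive windows share a cache), then are clipped once globally before a single optimizer step. A minimal sketch:

import torch
from torch import nn

net = nn.Linear(8, 1)
optim = torch.optim.Adam(net.parameters(), lr = 3e-4)
max_grad_norm = 0.5

x = torch.randn(2, 12, 8)

for window in x.split(4, dim = 1):   # gradients accumulate window by window
    loss = net(window).pow(2).mean()
    loss.backward()

nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)

optim.step()
optim.zero_grad()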
-    def
+    def state_and_command_to_rewards(
         self,
-        state
+        state,
+        commands = None,
+        env_index: int | None = None
     ) -> Tensor:

         assert self.has_reward_shaping
+        assert xnor(exists(env_index), self.reward_shaping_fns_multiple_envs), f'`env_index` must be passed in if multiple reward shaping functions are defined, and vice versa (not passed in if only single list of reward shaping functions)'
+
+        rewards = []
+
+        reward_shaping_fns = self.reward_shaping_fns[env_index] if exists(env_index) else self.reward_shaping_fns

-
+        for fn in reward_shaping_fns:
+            param_names = get_param_names(fn)
+            param_names = set(param_names) & {'state', 'command'}
+
+            if param_names == {'state'}: # only state
+                reward = fn(state = state)
+            elif param_names == {'state', 'command'}: # state and command
+                reward = fn(state = state, command = commands)
+            else:
+                raise ValueError('invalid number of arguments for reward shaping function')
+
+            rewards.append(reward)
+
+        # cast to Tensor if returns a float, just make it flexible for researcher

         rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
-        return stack(rewards)

-
+        assert all([r.numel() == 1 for r in rewards])
+
+        if len(rewards) == 0:
+            return None
+
+        packed_rewards, _ = pack(rewards, '*')
+        return packed_rewards
+
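Reward shaping terms are plain functions dispatched on their parameter names; `get_param_names` presumably wraps `inspect.signature`, which the module already imports. A sketch of the same dispatch with two hypothetical shaping terms:

import torch
from inspect import signature

def upright(state):
    # hypothetical shaping term that only needs the state
    return 1.0 - state[..., 0].abs().item()

def track_command(state, command):
    # hypothetical shaping term that compares state against a command
    return -float((state[..., 1] - command[..., 0]).abs())

def shaped_rewards(fns, state, command):
    rewards = []
    for fn in fns:
        params = set(signature(fn).parameters) & {'state', 'command'}
        if params == {'state'}:
            rewards.append(fn(state = state))
        else:
            rewards.append(fn(state = state, command = command))
    return torch.tensor(rewards)

print(shaped_rewards([upright, track_command], torch.randn(4), torch.randn(1)))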
+    @beartype
+    def wrap_env_functions(
+        self,
+        env,
+        env_output_transforms: dict[str, Callable] = dict(),
+        state_transform: Callable = identity,
+        reward_norm = 1.,
+        command_generator: Callable = always(None)
+    ):

         def transform_output(el):
             if isinstance(el, ndarray):
@@ -907,23 +1579,57 @@ class Locoformer(Module):
         def wrapped_reset(*args, **kwargs):
             env_reset_out = env.reset(*args, **kwargs)

-
+            env_reset_out_torch = tree_map(transform_output, env_reset_out)
+
+            env_reset_out_dict = self.parse_env_reset_out(env_reset_out_torch)
+
+            env_reset_out_dict['state'] = state_transform(env_reset_out_dict['state'])

-
+            derived_states = dict()
+
+            for derived_name, transform in env_output_transforms.items():
+                derived_states[derived_name] = transform(env_reset_out_dict, env)
+
+            env_reset_out_dict['derived_state'] = derived_states
+
+            return env_reset_out_dict
+
+        def wrapped_step(action, *args, command = None, env_index = None, **kwargs):

             if is_tensor(action):
-
+                if action.numel() == 1:
+                    action = action.item()
+                else:
+                    action = action.tolist()

             env_step_out = env.step(action, *args, **kwargs)

             env_step_out_torch = tree_map(transform_output, env_step_out)

-
-
+            env_step_out_dict = self.parse_env_step_out(env_step_out_torch)
+
+            env_step_out_dict['state'] = state_transform(env_step_out_dict['state'])
+
+            env_step_out_dict['reward'] = env_step_out_dict['reward'] / reward_norm
+
+            if self.has_reward_shaping:
+                shaped_rewards = self.state_and_command_to_rewards(env_step_out_dict['state'], command, env_index = env_index)
+
+                if exists(shaped_rewards):
+                    env_step_out_dict['shaped_rewards'] = shaped_rewards
+
+                    # add shaped rewards to main reward
+
+                    env_step_out_dict['reward'] = env_step_out_dict['reward'] + shaped_rewards.sum()

-
+            derived_states = dict()

-
+            for derived_name, transform in env_output_transforms.items():
+                derived_states[derived_name] = transform(env_step_out_dict, env)
+
+            env_step_out_dict['derived_state'] = derived_states
+
+            return env_step_out_dict

         return wrapped_reset, wrapped_step

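The wrapping pattern above, reduced to its essentials: environment outputs are mapped to tensors, the state gets an optional transform, and the reward is normalized, so the rollout loop only ever sees a uniform dict. A sketch against a stub environment (names illustrative, not the library's):

import numpy as np
import torch
from torch.utils._pytree import tree_map

class StubEnv:
    def reset(self):
        return np.zeros(4, dtype = np.float32)

    def step(self, action):
        return np.random.randn(4).astype(np.float32), 1.0, False, False

def wrap(env, reward_norm = 1.):
    def to_tensor(el):
        return torch.from_numpy(el) if isinstance(el, np.ndarray) else el

    def wrapped_reset():
        return dict(state = tree_map(to_tensor, env.reset()))

    def wrapped_step(action):
        state, reward, terminated, truncated = tree_map(to_tensor, env.step(action))
        return dict(state = state, reward = reward / reward_norm, terminated = terminated, truncated = truncated)

    return wrapped_reset, wrapped_step

reset_fn, step_fn = wrap(StubEnv(), reward_norm = 10.)
out = reset_fn()
out = step_fn(action = [0., 0.])
print(out['reward'])  # 0.1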
@@ -936,24 +1642,54 @@ class Locoformer(Module):
         state_time_dim = 1,
         **kwargs
     ):
-        window_size = self.window_size

         cache = None

-        def stateful_forward(
+        def stateful_forward(
+            state: Tensor,
+            condition: Tensor | None = None,
+            cond_mask: Tensor | None = None,
+            **override_kwargs
+        ):
             nonlocal cache

+            state = state.to(self.device)
+
+            if exists(condition):
+                condition = condition.to(self.device)
+
+            if exists(cond_mask):
+                cond_mask = cond_mask.to(self.device)
+
             # handle no batch or time, for easier time rolling out against envs

             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')

+                if exists(condition):
+                    condition = rearrange(condition, '... -> 1 ...')
+
+                if exists(cond_mask):
+                    cond_mask = rearrange(cond_mask, '... -> 1 ...')
+
             if not has_time_dim:
                 state = state.unsqueeze(state_time_dim)

+                if exists(condition):
+                    condition = rearrange(condition, '... d -> ... 1 d')
+
+                if exists(cond_mask):
+                    cond_mask = cond_mask.unsqueeze(state_time_dim)
+
             # forwards

-            out, cache = self.forward(
+            out, cache = self.forward(
+                state,
+                condition = condition,
+                cond_mask = cond_mask,
+                cache = cache,
+                **{**kwargs, **override_kwargs}
+            )

             # maybe remove batch or time

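The closure pattern behind `stateful_forward`, distilled: the kv cache lives in the enclosing scope and is rebound with `nonlocal`, so rollout code can call the function one timestep at a time without threading the cache through by hand:

def make_stateful(step_fn):
    # step_fn: (input, cache) -> (output, new cache)
    cache = None

    def stateful(x):
        nonlocal cache  # the closure owns the cache, like stateful_forward above
        out, cache = step_fn(x, cache)
        return out

    return stateful

# toy step function whose "cache" is just a running call count
count_calls = lambda x, cache: (x, (cache or 0) + 1)

f = make_stateful(count_calls)
f('a')
f('b')  # internal cache is now 2, without the caller ever touching it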
@@ -984,49 +1720,275 @@ class Locoformer(Module):

         return stateful_forward, initial_logits

+    @beartype
+    def gather_experience_from_env_(
+        self,
+        wrapped_env_functions: tuple[Callable, Callable],
+        replay: ReplayBuffer,
+        embed_past_action = False,
+        max_timesteps = None,
+        use_vision = False,
+        action_select_kwargs: dict = dict(),
+        state_embed_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
+        env_index: int | None = None,
+        state_entropy_bonus_weight = 0.,
+        action_rescale_range: tuple[float, float] | None = None,
+        command_fn: Callable = always(None)
+    ):
+
+        env_reset, env_step = wrapped_env_functions
+
+        reset_out_dict = env_reset()
+        derived, state = pick(reset_out_dict, ('derived_state', 'state'))
+
+        state_image = derived.get('state_image', None)
+        internal_state = derived.get('internal_state', None)
+
+        timestep = 0
+
+        max_timesteps = default(max_timesteps, replay.max_timesteps)
+
+        stateful_forward = self.get_stateful_forward(
+            has_batch_dim = False,
+            has_time_dim = False,
+            inference_mode = True
+        )
+
+        cum_rewards = 0.
+
+        with replay.one_episode() as final_meta_data_store_dict:
+
+            past_action = None
+
+            while True:
+                state_for_model = state_image if use_vision else state
+
+                maybe_command = command_fn(state_for_model)
+
+                # predict next action
+
+                (action_logits, state_pred), value = stateful_forward(
+                    state_for_model,
+                    condition = maybe_command,
+                    cond_mask = tensor(exists(maybe_command)),
+                    state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state},
+                    action_select_kwargs = action_select_kwargs,
+                    state_id_kwarg = state_id_kwarg,
+                    past_action = past_action if embed_past_action else None,
+                    return_values = True,
+                    return_state_pred = True
+                )
+
+                action = self.unembedder.sample(action_logits, **action_select_kwargs)
+
+                # maybe clip
+
+                if exists(action_rescale_range):
+                    min_val, max_val = action_rescale_range
+                    action = (action + 1.) * 0.5 * (max_val - min_val) + min_val
+
+                # pass to environment
+
+                step_dict = env_step(action, command = maybe_command, env_index = env_index)
+
+                derived, next_state, reward, terminated, truncated = pick(step_dict, ('derived_state', 'state', 'reward', 'terminated', 'truncated'))
+
+                next_state_image = derived.get('state_image', None)
+                next_internal_state = derived.get('internal_state', None)
+
+                # maybe state entropy bonus
+
+                if state_entropy_bonus_weight > 0. and exists(state_pred):
+                    state_id = state_id_kwarg.get('state_id', 0)
+                    entropy = self.state_pred_head.entropy(state_pred, selector_index = state_id)
+
+                    state_entropy_bonus = (entropy * state_entropy_bonus_weight).sum()
+
+                    reward = reward + state_entropy_bonus.item() # the entropy is directly related to log variance
+
+                # cum rewards
+
+                cum_rewards += reward
+
+                # increment counters
+                # we will store the step with done=False, as only the bootstrap/boundary node is done=True
+
+                exceeds_max_timesteps = max_timesteps >= 0 and timestep == (max_timesteps - 1)
+                should_stop = truncated or terminated or tensor(exceeds_max_timesteps)
+
+                # get log prob of action
+
+                action_log_prob = self.unembedder.log_prob(action_logits, action, **action_select_kwargs)
+
+                memory = replay.store(
+                    state = state,
+                    state_image = state_image,
+                    action = action,
+                    action_log_prob = action_log_prob,
+                    internal_state = internal_state,
+                    reward = reward,
+                    value = value,
+                    done = tensor(False),
+                    condition = maybe_command,
+                    cond_mask = tensor(exists(maybe_command))
+                )

+                timestep += 1
+
+                # break if done or exceed max timestep
+                if should_stop:
+
+                    # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                    # only if terminated signal not detected
+                    if not terminated:
+                        next_state_for_model = next_state_image if use_vision else next_state
+
+                        _, next_value = stateful_forward(next_state_for_model, condition = maybe_command, cond_mask = tensor(exists(maybe_command)), return_values = True, state_embed_kwargs = {**state_embed_kwargs, 'internal_state': internal_state}, state_id_kwarg = state_id_kwarg, action_select_kwargs = action_select_kwargs)
+
+                        terminal_node = dict(
+                            state = next_state,
+                            state_image = next_state_image,
+                            internal_state = next_internal_state,
+                            value = next_value,
+                            reward = next_value,
+                            done = True,
+                            condition = maybe_command,
+                            cond_mask = exists(maybe_command)
+                        )
+
+                    else:
+                        # terminal node - store a step with 0 reward and value, and done=True, to stop GAE scan
+                        terminal_node = dict(
+                            state = next_state,
+                            state_image = next_state_image,
+                            internal_state = next_internal_state,
+                            value = torch.zeros_like(value),
+                            reward = torch.zeros_like(reward),
+                            done = True,
+                            condition = maybe_command,
+                            cond_mask = exists(maybe_command)
+                        )
+
+                    terminal_node = {key: value for key, value in terminal_node.items() if key in memory._fields}
+
+                    terminal_memory = memory._replace(**terminal_node)
+
+                    replay.store(**terminal_memory._asdict())
+
+                    # store the final cumulative reward into meta data
+
+                    final_meta_data_store_dict.update(cum_rewards = cum_rewards)
+
+                    break
+
+                state = next_state
+                state_image = next_state_image
+                internal_state = next_internal_state
+
+                past_action = action
+
+        return cum_rewards
+
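Why the boundary node above carries the bootstrap value: with one extra done = True step holding V(s_T) on truncation (or zeros on true termination), a single GAE recursion over the stored arrays needs no special casing. A reference loop showing one standard way to consume such a boundary node (the package itself imports an associative scan for this):

import torch

def gae(rewards, values, gamma = 0.99, lam = 0.95):
    # rewards: (T,) for the real steps. values: (T + 1,), with the boundary
    # node's value last: V(s_T) when truncated, 0. when truly terminated
    advantages = torch.zeros_like(rewards)
    running = torch.tensor(0.)
    for t in reversed(range(rewards.numel())):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

rewards = torch.tensor([1., 1., 1.])
values = torch.tensor([0.9, 0.8, 0.7, 0.5])  # 0.5 = bootstrap V(s_T)
print(gae(rewards, values))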
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: TransformerMemory | None = None,
+        condition: Tensor | None = None,
+        cond_mask: Tensor | None = None,
+        past_action: Tensor | None = None,
+        state_embed_kwargs: dict = dict(),
+        action_select_kwargs: dict = dict(),
+        state_id_kwarg: dict = dict(),
         detach_cache = False,
         return_values = False,
+        return_state_pred = False,
         return_raw_value_logits = False
     ):

         state = state.to(self.device)

-
+        # move condition

-
+        if exists(condition):
+            condition = condition.to(self.device)

-
+        # determine which function to invoke for state to token for transformer

-
+        state_to_token = self.embedder

-
-        timestep_start = 0
+        # embed

-
-
+        tokens = state_to_token(state, **state_embed_kwargs, **state_id_kwarg)
+
+        # maybe add past action
+
+        # determine if first window and start of sequence
+
+        total_tokens = cache.total_tokens if exists(cache) else 0
+
+        is_start_of_sequence = total_tokens == 0
+
+        # maybe add past action
+
+        if exists(past_action):
+            assert self.embed_past_action
+            past_action_embed = self.past_action_embedder(past_action, **action_select_kwargs)
+
+            if is_start_of_sequence:
+                past_action_embed = pad_at_dim(past_action_embed[..., 1:, :], (1, 0), dim = -2)
+
+            tokens = tokens + past_action_embed
+
+        # time
+
+        time = tokens.shape[-2]

         # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory

-        assert ((
+        assert ((total_tokens % self.window_size) + time) <= self.window_size

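The asserted invariant in isolation: a forward call may never straddle a window boundary, so callers either step one token at a time or feed window-aligned chunks:

window_size = 16

def fits_in_window(total_tokens, incoming):
    return (total_tokens % window_size) + incoming <= window_size

assert fits_in_window(0, 16)       # a full first window from a fresh cache
assert fits_in_window(16, 16)      # an aligned second window
assert fits_in_window(20, 4)       # 20 % 16 = 4 slots used, 4 more still fit
assert not fits_in_window(12, 8)   # 12 + 8 = 20 would straddle the boundary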
         # attention

-
+        if not exists(cache):
+            memory_segments = deque(maxlen = self.max_mem_segments)
+        else:
+            total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache
+
+        attn_cache = cache
+
+        embed, cache = self.transformer(
+            tokens,
+            condition = condition,
+            cond_mask = cond_mask,
+            cache = attn_cache,
+            return_kv_cache = True
+        )

         # unembed to actions - in language models this would be the next state

-
+        policy_embed = self.policy_network(embed)
+
+        action_logits = self.unembedder(policy_embed, **action_select_kwargs)

         out = action_logits

+        # maybe return state prediction
+
+        if return_state_pred:
+            state_pred = None
+
+            if self.can_pred_state:
+                state_id = state_id_kwarg.get('state_id', 0)
+                state_pred_embed = self.state_pred_network(embed)
+                state_pred = self.state_pred_head(state_pred_embed, selector_index = state_id)
+
+            out = (out, state_pred)
+
         # maybe detach cache

         if detach_cache:
-
+            cache = tree_map_tensor(cache, lambda t: t.detach())

         # handle returning of values

@@ -1040,20 +2002,22 @@ class Locoformer(Module):

             out = (out, values)

-        #
+        # handle curtailing kv cache at the right intervals

-
+        total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments = cache

-
+        if self.fixed_window_size or divisible_by(total_tokens, self.window_size * 2):
+            kv_cache = kv_cache[..., -self.window_size:, :]

-
+        if self.recurrent_cache and divisible_by(total_tokens, self.window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)

-
-
+            if exists(gru_cache):
+                gru_cache = torch.roll(gru_cache, shifts = -1, dims = 0)

-
+        if divisible_by(total_tokens, self.window_size):
+            memory_segments.append(kv_cache.detach())

-
-        kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+        cache = TransformerMemory(total_tokens, kv_cache, gru_cache, mem_mlp_cache, mem_mlp_hidden_states, memory_segments)

-        return out,
+        return out, cache
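The cache curtailment above, in miniature: once total_tokens crosses the right interval, the kv cache is truncated to the most recent window along the sequence axis, and under the recurrent variant the stacked per-layer caches are rotated one position with `torch.roll` along dim 0 (one plausible reading: each layer then reads memory written by a neighboring layer on the next window). Shape sketch only, with assumed cache layout:

import torch

layers, batch, heads, seq, dim_head = 4, 1, 8, 32, 64
window_size = 16

kv_cache = torch.randn(layers, batch, heads, seq, dim_head)

# keep only the most recent window of keys / values
kv_cache = kv_cache[..., -window_size:, :]
print(kv_cache.shape)  # torch.Size([4, 1, 8, 16, 64])

# recurrent variant: rotate the caches across the layer dimension
kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)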