dreamer4 0.0.7__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 might be problematic.
- dreamer4/__init__.py +8 -2
- dreamer4/dreamer4.py +2808 -814
- dreamer4/mocks.py +97 -0
- dreamer4/trainers.py +525 -3
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/METADATA +97 -11
- dreamer4-0.1.16.dist-info/RECORD +8 -0
- dreamer4-0.0.7.dist-info/RECORD +0 -7
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -3,19 +3,31 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
-from functools import partial
+from functools import partial, wraps
+from dataclasses import dataclass, asdict
 
 import torch
 import torch.nn.functional as F
-from torch.
-from torch import
+from torch.nested import nested_tensor
+from torch.distributions import Normal, kl
+from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
 
-from
+from torch.optim import Optimizer
+from adam_atan2_pytorch import MuonAdamAtan2
+
 from x_mlps_pytorch.ensemble import Ensemble
+from x_mlps_pytorch.normed_mlp import create_mlp
+
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
 
 from assoc_scan import AssocScan
 
@@ -27,13 +39,18 @@ from assoc_scan import AssocScan
 # d - feature dimension
 # f - frequencies (rotary)
 # l - logit / predicted bins
+# y - layers of transformer
 # p - positions (3 for spacetime in this work)
 # t - time
+# na - action dimension (number of discrete and continuous actions)
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
+# mtp - multi token prediction length
+# v - video viewpoints
 
 import einx
+from einx import add, multiply
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
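The new einx add/multiply imports are used further down for named-axis broadcasting (e.g. add('t, steps -> t steps', ...) in the multi token prediction helper). A quick standalone sketch of what those calls do (shapes here are made up):

import torch
from einx import add, multiply

a = torch.arange(4)          # axis t
b = torch.arange(2)          # axis steps

# outer sum: result[t, steps] = a[t] + b[steps]
print(add('t, steps -> t steps', a, b))
# tensor([[0, 1],
#         [1, 2],
#         [2, 3],
#         [3, 4]])

# scale each row of a (na, d) table by a per-row factor
print(multiply('na d, na -> na d', torch.ones(3, 5), torch.tensor([1., 2., 3.]))[:, 0])
# tensor([1., 2., 3.])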
@@ -53,7 +70,95 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
+
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
+
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
+MaybeTensor = Tensor | None
+
+@dataclass
+class Experience:
+    latents: Tensor
+    video: MaybeTensor = None
+    proprio: MaybeTensor = None
+    agent_embed: MaybeTensor = None
+    rewards: Tensor | None = None
+    actions: tuple[MaybeTensor, MaybeTensor] | None = None
+    log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
+    old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
+    values: MaybeTensor = None
+    step_size: int | None = None
+    lens: MaybeTensor = None
+    is_truncated: MaybeTensor = None
+    agent_index: int = 0
+    is_from_world_model: bool = True
+
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
+def combine_experiences(
+    exps: list[Experiences]
+) -> Experience:
+
+    assert len(exps) > 0
+
+    # set lens if not there
+
+    for exp in exps:
+        latents = exp.latents
+        batch, time, device = *latents.shape[:2], latents.device
+
+        if not exists(exp.lens):
+            exp.lens = full((batch,), time, device = device)
+
+        if not exists(exp.is_truncated):
+            exp.is_truncated = full((batch,), True, device = device)
+
+    # convert to dictionary
+
+    exps_dict = [asdict(exp) for exp in exps]
+
+    values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
+
+    tree_spec = first(tree_specs)
+
+    all_field_values = list(zip(*values))
+
+    # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
+
+    assert all([
+        all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
+        for field_values in all_field_values
+    ])
+
+    concatted = []
+
+    for field_values in all_field_values:
+
+        if is_tensor(first(field_values)):
+
+            field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
+
+            new_field_value = cat(field_values)
+        else:
+            new_field_value = first(list(set(field_values)))
+
+        concatted.append(new_field_value)
+
+    # return experience
+
+    concat_exp_dict = tree_unflatten(concatted, tree_spec)
+
+    return Experience(**concat_exp_dict)
 
 # helpers
 
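A minimal usage sketch for the new Experience container and combine_experiences (illustrative only: it assumes these definitions are importable from dreamer4.dreamer4, and the shapes are invented). combine_experiences fills in lens / is_truncated when absent, right-pads tensor fields to the max length, and concatenates along batch:

import torch
from dreamer4.dreamer4 import Experience, combine_experiences

# two rollout batches with different batch sizes and episode lengths
a = Experience(latents = torch.randn(2, 8, 32))
b = Experience(latents = torch.randn(3, 5, 32))

merged = combine_experiences([a, b])

# latents are padded along time to the max length before concatenation
print(merged.latents.shape)  # expected: torch.Size([5, 8, 32])
print(merged.lens)           # expected: tensor([8, 8, 5, 5, 5])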
@@ -66,6 +171,15 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
+def has_at_least_one(*bools):
+    return sum([*map(int, bools)]) > 0
+
+def ensure_tuple(t):
+    return (t,) if not isinstance(t, tuple) else t
+
 def divisible_by(num, den):
     return (num % den) == 0
 
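A standalone illustration of the three small predicates added here (definitions copied inline so the snippet runs on its own):

def xnor(x, y):
    return not (x ^ y)                      # True when both booleans agree

def has_at_least_one(*bools):
    return sum([*map(int, bools)]) > 0      # any(...) written arithmetically

def ensure_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

assert xnor(True, True) and xnor(False, False) and not xnor(True, False)
assert has_at_least_one(False, True) and not has_at_least_one(False, False)
assert ensure_tuple(3) == (3,) and ensure_tuple((3, 4)) == (3, 4)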
@@ -75,8 +189,88 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
+def is_empty(t):
+    return t.numel() == 0
+
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
+def safe_stack(tensors, dim = 0):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+
+    return stack(tensors, dim = dim)
+
+def safe_cat(tensors, dim):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+    elif len(tensors) == 1:
+        return tensors[0]
+
+    return cat(tensors, dim = dim)
+
+def safe_squeeze_first(t):
+    if not exists(t):
+        return None
+
+    if t.shape[0] != 1:
+        return t
+
+    return rearrange(t, '1 ... -> ...')
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(
+    t,
+    temperature = 1.,
+    dim = -1,
+    keepdim = False,
+    eps = 1e-10
+):
+    noised = (t / max(temperature, eps)) + gumbel_noise(t)
+    return noised.argmax(dim = dim, keepdim = keepdim)
+
 def pack_one(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
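gumbel_sample above implements the Gumbel-max trick: adding Gumbel noise to temperature-scaled logits and taking the argmax is equivalent to drawing from the softmax distribution over those logits. A standalone sketch of that equivalence (helper definitions copied inline):

import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

logits = torch.tensor([1., 2., 3.]).expand(100_000, 3)

# temperature = 1: argmax over noised logits is a categorical sample
samples = (logits + gumbel_noise(logits)).argmax(dim = -1)

empirical = torch.bincount(samples, minlength = 3) / samples.numel()
print(empirical)                 # approaches softmax(logits) ≈ [0.09, 0.24, 0.67]
print(logits[0].softmax(dim = -1))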
@@ -99,6 +293,27 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def pad_to_len(t, target_len, *, dim):
+    curr_len = t.shape[dim]
+
+    if curr_len >= target_len:
+        return t
+
+    return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
+
+def pad_tensors_at_dim_to_max_len(
+    tensors: list[Tensor],
+    dims: tuple[int, ...]
+):
+    for dim in dims:
+        if dim >= first(tensors).ndim:
+            continue
+
+        max_time = max([t.shape[dim] for t in tensors])
+        tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
+
+    return tensors
+
 def align_dims_left(t, aligned_to):
     shape = t.shape
     num_right_dims = aligned_to.ndim - t.ndim
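pad_to_len / pad_tensors_at_dim_to_max_len right-pad a ragged list of tensors so they can be concatenated. The same idea with plain F.pad, as a standalone sketch (pad_dim1_to is a hypothetical helper for this example, not part of the library):

import torch
import torch.nn.functional as F

def pad_dim1_to(t, target_len):
    # F.pad pads from the last dim backwards: (0, 0) leaves dim -1 alone,
    # (0, n) right-pads dim -2 (time) with zeros
    return F.pad(t, (0, 0, 0, target_len - t.shape[1]))

tensors = [torch.randn(2, 8, 32), torch.randn(3, 5, 32)]
max_len = max(t.shape[1] for t in tensors)
padded = [pad_dim1_to(t, max_len) for t in tensors]
print(torch.cat(padded).shape)    # torch.Size([5, 8, 32])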
@@ -114,8 +329,74 @@ def l2norm(t):
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+def create_multi_token_prediction_targets(
+    t, # (b t ...)
+    steps_future,
+
+): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
+
+    batch, seq_len, device = *t.shape[:2], t.device
+
+    batch_arange = arange(batch, device = device)
+    seq_arange = arange(seq_len, device = device)
+    steps_arange = arange(steps_future, device = device)
+
+    indices = add('t, steps -> t steps', seq_arange, steps_arange)
+    mask = indices < seq_len
+
+    batch_arange = rearrange(batch_arange, 'b -> b 1 1')
+
+    indices[~mask] = 0
+    mask = repeat(mask, 't steps -> b t steps', b = batch)
+
+    out = t[batch_arange, indices]
+
+    return out, mask
+
 # loss related
 
+class LossNormalizer(Module):
+
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor | list[Tensor] | dict[str, Tensor],
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
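create_multi_token_prediction_targets gathers, for each position t, the values at t, t+1, ..., t+steps-1, and masks windows that run past the sequence end. A toy re-derivation of that indexing in plain torch (not calling the library function):

import torch

t = torch.arange(10, 14).reshape(1, 4)              # batch 1, seq [10..13]
steps = 2

indices = torch.arange(4)[:, None] + torch.arange(steps)[None, :]
mask = indices < 4
indices = indices.masked_fill(~mask, 0)             # clamp out-of-range to 0

targets = t[:, indices]                             # (1, 4, steps)
print(targets[0])  # [[10, 11], [11, 12], [12, 13], [13, 10]] - last row padded
print(mask)        # [[True, True], [True, True], [True, True], [True, False]]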
@@ -267,1007 +548,2281 @@ class SymExpTwoHot(Module):
|
|
|
267
548
|
|
|
268
549
|
return inverse_pack(encoded, '* l')
|
|
269
550
|
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
@torch.no_grad()
|
|
273
|
-
def calc_gae(
|
|
274
|
-
rewards,
|
|
275
|
-
values,
|
|
276
|
-
masks,
|
|
277
|
-
gamma = 0.99,
|
|
278
|
-
lam = 0.95,
|
|
279
|
-
use_accelerated = None
|
|
280
|
-
):
|
|
281
|
-
assert values.shape[-1] == rewards.shape[-1]
|
|
282
|
-
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
|
283
|
-
|
|
284
|
-
values = F.pad(values, (0, 1), value = 0.)
|
|
285
|
-
values, values_next = values[..., :-1], values[..., 1:]
|
|
286
|
-
|
|
287
|
-
delta = rewards + gamma * values_next * masks - values
|
|
288
|
-
gates = gamma * lam * masks
|
|
289
|
-
|
|
290
|
-
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
|
291
|
-
|
|
292
|
-
gae = scan(gates, delta)
|
|
293
|
-
|
|
294
|
-
returns = gae + values
|
|
551
|
+
# action related
|
|
295
552
|
|
|
296
|
-
|
|
553
|
+
ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
|
|
297
554
|
|
|
298
|
-
|
|
299
|
-
# https://jerryxio.ng/posts/nd-rope/
|
|
300
|
-
|
|
301
|
-
def _phi(m):
|
|
302
|
-
x = 2.
|
|
303
|
-
for _ in range(10):
|
|
304
|
-
x = (1. + x) ** (1. / (m + 1.))
|
|
305
|
-
return x
|
|
306
|
-
|
|
307
|
-
def make_directions(n, d):
|
|
308
|
-
g = _phi(d)
|
|
309
|
-
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
|
310
|
-
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
|
311
|
-
z = torch.fmod(i * alpha, 1.0)
|
|
312
|
-
directions = torch.erfinv(2.0 * z - 1.0)
|
|
313
|
-
directions = l2norm(directions)
|
|
314
|
-
return directions.float()
|
|
315
|
-
|
|
316
|
-
class GoldenGateRoPENd(Module):
|
|
555
|
+
class ActionEmbedder(Module):
|
|
317
556
|
def __init__(
|
|
318
557
|
self,
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
558
|
+
dim,
|
|
559
|
+
*,
|
|
560
|
+
num_discrete_actions: int | tuple[int, ...] = 0,
|
|
561
|
+
num_continuous_actions = 0,
|
|
562
|
+
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
|
|
563
|
+
can_unembed = False,
|
|
564
|
+
unembed_dim = None,
|
|
565
|
+
num_unembed_preds = 1,
|
|
566
|
+
squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
|
|
325
567
|
):
|
|
326
568
|
super().__init__()
|
|
327
|
-
assert divisible_by(dim_head, 2)
|
|
328
569
|
|
|
329
|
-
|
|
330
|
-
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
|
570
|
+
# handle discrete actions
|
|
331
571
|
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
|
335
|
-
))
|
|
572
|
+
num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
|
|
573
|
+
total_discrete_actions = num_discrete_actions.sum().item()
|
|
336
574
|
|
|
337
|
-
|
|
338
|
-
|
|
575
|
+
self.num_discrete_action_types = len(num_discrete_actions)
|
|
576
|
+
self.discrete_action_embed = Embedding(total_discrete_actions, dim)
|
|
339
577
|
|
|
340
|
-
|
|
341
|
-
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
|
578
|
+
self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
|
|
342
579
|
|
|
343
|
-
|
|
344
|
-
self,
|
|
345
|
-
pos # (b n p)
|
|
346
|
-
):
|
|
580
|
+
# continuous actions
|
|
347
581
|
|
|
348
|
-
|
|
349
|
-
|
|
582
|
+
self.num_continuous_action_types = num_continuous_actions
|
|
583
|
+
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
|
|
350
584
|
|
|
351
|
-
|
|
585
|
+
self.continuous_need_norm = exists(continuous_norm_stats)
|
|
352
586
|
|
|
353
|
-
|
|
587
|
+
if self.continuous_need_norm:
|
|
588
|
+
self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
|
|
354
589
|
|
|
355
|
-
|
|
590
|
+
# defaults
|
|
356
591
|
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
self,
|
|
360
|
-
dim_head,
|
|
361
|
-
theta = 10000.
|
|
362
|
-
):
|
|
363
|
-
super().__init__()
|
|
364
|
-
inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
|
|
365
|
-
self.register_buffer('inv_freq', inv_freq)
|
|
592
|
+
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
|
|
593
|
+
self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
|
|
366
594
|
|
|
367
|
-
|
|
368
|
-
self,
|
|
369
|
-
seq_len
|
|
370
|
-
):
|
|
371
|
-
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
|
595
|
+
# calculate offsets
|
|
372
596
|
|
|
373
|
-
|
|
374
|
-
|
|
597
|
+
offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
|
598
|
+
self.register_buffer('discrete_action_offsets', offsets, persistent = False)
|
|
375
599
|
|
|
376
|
-
|
|
600
|
+
# unembedding
|
|
377
601
|
|
|
602
|
+
self.can_unembed = can_unembed
|
|
378
603
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
t # (b h n d)
|
|
382
|
-
):
|
|
383
|
-
heads, dtype = t.shape[1], t.dtype
|
|
384
|
-
t = t.float()
|
|
604
|
+
self.num_unembed_preds = num_unembed_preds
|
|
605
|
+
self.squeeze_unembed_preds = squeeze_unembed_preds
|
|
385
606
|
|
|
386
|
-
|
|
607
|
+
if not can_unembed:
|
|
608
|
+
return
|
|
387
609
|
|
|
388
|
-
|
|
389
|
-
|
|
610
|
+
unembed_dim = default(unembed_dim, dim)
|
|
611
|
+
self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
|
|
390
612
|
|
|
391
|
-
|
|
392
|
-
groups = heads // rotary_heads
|
|
393
|
-
rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
|
|
613
|
+
discrete_action_index = arange(total_discrete_actions)
|
|
394
614
|
|
|
395
|
-
|
|
396
|
-
|
|
615
|
+
padded_num_discrete_actions = F.pad(num_discrete_actions, (1, 0), value = 0)
|
|
616
|
+
exclusive_cumsum = padded_num_discrete_actions.cumsum(dim = -1)
|
|
397
617
|
|
|
398
|
-
|
|
618
|
+
discrete_action_mask = (
|
|
619
|
+
einx.greater_equal('j, i -> i j', discrete_action_index, exclusive_cumsum[:-1]) &
|
|
620
|
+
einx.less('j, i -> i j', discrete_action_index, exclusive_cumsum[1:])
|
|
621
|
+
)
|
|
399
622
|
|
|
400
|
-
|
|
401
|
-
return rotated.type(dtype)
|
|
623
|
+
self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
|
|
402
624
|
|
|
403
|
-
|
|
625
|
+
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
|
|
404
626
|
|
|
405
|
-
|
|
406
|
-
|
|
627
|
+
def embed_parameters(self):
|
|
628
|
+
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
|
|
629
|
+
|
|
630
|
+
def unembed_parameters(self):
|
|
631
|
+
return set([self.discrete_action_unembed, self.continuous_action_unembed])
|
|
632
|
+
|
|
633
|
+
@property
|
|
634
|
+
def device(self):
|
|
635
|
+
return self.discrete_action_offsets.device
|
|
636
|
+
|
|
637
|
+
@property
|
|
638
|
+
def has_actions(self):
|
|
639
|
+
return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
|
|
640
|
+
|
|
641
|
+
def cast_action_types(
|
|
407
642
|
self,
|
|
408
|
-
|
|
409
|
-
heads = 8
|
|
643
|
+
action_types = None
|
|
410
644
|
):
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
645
|
+
if exists(action_types) and not is_tensor(action_types):
|
|
646
|
+
if isinstance(action_types, int):
|
|
647
|
+
action_types = (action_types,)
|
|
414
648
|
|
|
415
|
-
|
|
649
|
+
action_types = tensor(action_types, device = self.device)
|
|
650
|
+
|
|
651
|
+
return action_types
|
|
652
|
+
|
|
653
|
+
def unembed(
|
|
416
654
|
self,
|
|
417
|
-
|
|
418
|
-
)
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
655
|
+
embeds, # (... d)
|
|
656
|
+
discrete_action_types = None, # (na)
|
|
657
|
+
continuous_action_types = None, # (na)
|
|
658
|
+
return_split_discrete = False,
|
|
659
|
+
pred_head_index: int | Tensor | None = None
|
|
422
660
|
|
|
423
|
-
#
|
|
661
|
+
): # (... discrete_na), (... continuous_na 2)
|
|
424
662
|
|
|
425
|
-
|
|
426
|
-
q, k, v,
|
|
427
|
-
softclamp_value = None,
|
|
428
|
-
scale = None,
|
|
429
|
-
causal = False,
|
|
430
|
-
causal_block_size = 1,
|
|
431
|
-
mask = None
|
|
432
|
-
):
|
|
663
|
+
device = embeds.device
|
|
433
664
|
|
|
434
|
-
|
|
435
|
-
scale = q.shape[-1] ** -0.5
|
|
665
|
+
assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
|
|
436
666
|
|
|
437
|
-
|
|
667
|
+
# handle only one prediction head during inference
|
|
438
668
|
|
|
439
|
-
|
|
669
|
+
if exists(pred_head_index) and isinstance(pred_head_index, int):
|
|
670
|
+
pred_head_index = tensor(pred_head_index, device = device)
|
|
440
671
|
|
|
441
|
-
|
|
672
|
+
# if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
|
|
442
673
|
|
|
443
|
-
|
|
674
|
+
squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
|
|
444
675
|
|
|
445
|
-
|
|
676
|
+
# get action types
|
|
446
677
|
|
|
447
|
-
|
|
678
|
+
discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
|
|
448
679
|
|
|
449
|
-
|
|
680
|
+
# discrete actions
|
|
450
681
|
|
|
451
|
-
|
|
682
|
+
discrete_action_logits = None
|
|
452
683
|
|
|
453
|
-
|
|
454
|
-
sim = softclamp(sim, softclamp_value)
|
|
684
|
+
if self.num_discrete_action_types > 0:
|
|
455
685
|
|
|
456
|
-
|
|
686
|
+
discrete_action_unembed = self.discrete_action_unembed
|
|
457
687
|
|
|
458
|
-
|
|
688
|
+
if exists(discrete_action_types):
|
|
689
|
+
discrete_action_mask = self.discrete_action_mask[discrete_action_types].any(dim = 0)
|
|
459
690
|
|
|
460
|
-
|
|
461
|
-
sim = sim.masked_fill(~mask, mask_value)
|
|
691
|
+
discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
|
|
462
692
|
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
i, j = sim.shape[-2:]
|
|
693
|
+
if exists(pred_head_index):
|
|
694
|
+
discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
|
|
466
695
|
|
|
467
|
-
|
|
468
|
-
i = ceil(i / causal_block_size)
|
|
469
|
-
j = ceil(j / causal_block_size)
|
|
696
|
+
discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
|
|
470
697
|
|
|
471
|
-
|
|
698
|
+
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
|
699
|
+
discrete_action_logits = safe_squeeze_first(discrete_action_logits)
|
|
472
700
|
|
|
473
|
-
|
|
474
|
-
causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
|
|
475
|
-
causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
|
|
701
|
+
# whether to split the discrete action logits by the number of actions per action type
|
|
476
702
|
|
|
477
|
-
|
|
703
|
+
if exists(discrete_action_logits) and return_split_discrete:
|
|
478
704
|
|
|
479
|
-
|
|
705
|
+
split_sizes = self.num_discrete_actions[discrete_action_types] if exists(discrete_action_types) else self.num_discrete_actions
|
|
480
706
|
|
|
481
|
-
|
|
707
|
+
discrete_action_logits = discrete_action_logits.split(split_sizes.tolist(), dim = -1)
|
|
482
708
|
|
|
483
|
-
|
|
709
|
+
# continuous actions
|
|
484
710
|
|
|
485
|
-
|
|
711
|
+
continuous_action_mean_log_var = None
|
|
486
712
|
|
|
487
|
-
|
|
713
|
+
if self.num_continuous_action_types > 0:
|
|
488
714
|
|
|
489
|
-
|
|
715
|
+
continuous_action_unembed = self.continuous_action_unembed
|
|
490
716
|
|
|
491
|
-
|
|
717
|
+
if exists(continuous_action_types):
|
|
718
|
+
continuous_action_unembed = continuous_action_unembed[continuous_action_types]
|
|
492
719
|
|
|
493
|
-
|
|
720
|
+
if exists(pred_head_index):
|
|
721
|
+
continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
|
|
494
722
|
|
|
495
|
-
|
|
496
|
-
bq = q // block_size
|
|
497
|
-
bk = k // block_size
|
|
498
|
-
return bq >= bk
|
|
723
|
+
continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
|
|
499
724
|
|
|
500
|
-
|
|
725
|
+
if self.squeeze_unembed_preds or squeeze_one_pred_head:
|
|
726
|
+
continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
|
|
501
727
|
|
|
502
|
-
|
|
503
|
-
bq = q % seq_len
|
|
504
|
-
bk = k % seq_len
|
|
728
|
+
return discrete_action_logits, continuous_action_mean_log_var
|
|
505
729
|
|
|
506
|
-
|
|
730
|
+
def sample(
|
|
731
|
+
self,
|
|
732
|
+
embed,
|
|
733
|
+
discrete_temperature = 1.,
|
|
734
|
+
continuous_temperature = 1.,
|
|
735
|
+
inverse_norm_continuous = None,
|
|
736
|
+
pred_head_index: int | Tensor | None = None,
|
|
737
|
+
squeeze = True,
|
|
738
|
+
**kwargs
|
|
739
|
+
):
|
|
740
|
+
inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
|
|
507
741
|
|
|
508
|
-
|
|
509
|
-
k_is_special = k >= is_special_start_index
|
|
742
|
+
discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
|
|
510
743
|
|
|
511
|
-
|
|
512
|
-
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
|
|
513
|
-
else:
|
|
514
|
-
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
|
744
|
+
sampled_discrete = sampled_continuous = None
|
|
515
745
|
|
|
516
|
-
|
|
746
|
+
if exists(discrete_logits):
|
|
747
|
+
sampled_discrete = []
|
|
517
748
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
num_tokens
|
|
521
|
-
):
|
|
522
|
-
def inner(b, h, q, k):
|
|
523
|
-
return special_token_mask(q, k, seq_len, num_tokens)
|
|
524
|
-
return inner
|
|
749
|
+
for one_discrete_logits in discrete_logits:
|
|
750
|
+
sampled_discrete.append(gumbel_sample(one_discrete_logits, temperature = discrete_temperature, keepdim = True))
|
|
525
751
|
|
|
526
|
-
|
|
527
|
-
def inner(b, h, q, k):
|
|
528
|
-
return mask1(b, h, q, k) & mask2(b, h, q, k)
|
|
752
|
+
sampled_discrete = cat(sampled_discrete, dim = -1)
|
|
529
753
|
|
|
530
|
-
|
|
754
|
+
if exists(continuous_mean_log_var):
|
|
755
|
+
mean, log_var = continuous_mean_log_var.unbind(dim = -1)
|
|
756
|
+
std = (0.5 * log_var).exp()
|
|
531
757
|
|
|
532
|
-
|
|
533
|
-
return b >= 0
|
|
758
|
+
sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
|
|
534
759
|
|
|
535
|
-
|
|
536
|
-
def inner(sim, b, h, q, k):
|
|
537
|
-
if not exists(value):
|
|
538
|
-
return sim
|
|
760
|
+
# maybe inverse norm
|
|
539
761
|
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
return sim
|
|
762
|
+
if inverse_norm_continuous:
|
|
763
|
+
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
|
764
|
+
sampled_continuous = (sampled_continuous * norm_std) + norm_mean
|
|
544
765
|
|
|
545
|
-
|
|
766
|
+
return sampled_discrete, sampled_continuous
|
|
546
767
|
|
|
547
|
-
|
|
768
|
+
def log_probs(
|
|
769
|
+
self,
|
|
770
|
+
embeds, # (... d)
|
|
771
|
+
discrete_targets = None, # (... na)
|
|
772
|
+
continuous_targets = None, # (... na)
|
|
773
|
+
discrete_action_types = None, # (na)
|
|
774
|
+
continuous_action_types = None, # (na)
|
|
775
|
+
pred_head_index: int | Tensor | None = None,
|
|
776
|
+
parallel_discrete_calc = None,
|
|
777
|
+
return_entropies = False
|
|
778
|
+
):
|
|
779
|
+
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
|
|
548
780
|
|
|
549
|
-
|
|
550
|
-
use_flex,
|
|
551
|
-
seq_len,
|
|
552
|
-
k_seq_len,
|
|
553
|
-
causal = False,
|
|
554
|
-
causal_block_size = 1,
|
|
555
|
-
softclamp_value = 50.,
|
|
556
|
-
num_special_tokens = 0, # special tokens are latents / agents
|
|
557
|
-
block_size_per_special = None, # defaults to k_seq_len
|
|
558
|
-
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
|
|
559
|
-
device = None
|
|
560
|
-
):
|
|
561
|
-
block_size_per_special = default(block_size_per_special, k_seq_len)
|
|
781
|
+
discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
|
|
562
782
|
|
|
563
|
-
|
|
564
|
-
# flex pathway
|
|
783
|
+
# discrete
|
|
565
784
|
|
|
566
|
-
|
|
785
|
+
discrete_log_probs = None
|
|
786
|
+
discrete_entropies = None
|
|
567
787
|
|
|
568
|
-
if
|
|
569
|
-
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
570
|
-
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
|
|
788
|
+
if exists(discrete_targets):
|
|
571
789
|
|
|
572
|
-
|
|
790
|
+
if parallel_discrete_calc:
|
|
791
|
+
# use nested tensors
|
|
573
792
|
|
|
574
|
-
|
|
575
|
-
attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
|
|
576
|
-
else:
|
|
577
|
-
# naive pathway
|
|
793
|
+
jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
|
|
578
794
|
|
|
579
|
-
|
|
580
|
-
if num_special_tokens > 0:
|
|
581
|
-
q_seq = torch.arange(seq_len, device = device)[:, None]
|
|
582
|
-
k_seq = torch.arange(k_seq_len, device = device)[None, :]
|
|
795
|
+
discrete_action_logits = cat(discrete_action_logits, dim = -1)
|
|
583
796
|
|
|
584
|
-
|
|
797
|
+
discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
|
|
798
|
+
batch = discrete_action_logits.shape[0]
|
|
585
799
|
|
|
586
|
-
|
|
800
|
+
discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
|
|
587
801
|
|
|
588
|
-
|
|
802
|
+
# to nested tensor
|
|
589
803
|
|
|
590
|
-
|
|
804
|
+
nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
|
|
591
805
|
|
|
592
|
-
|
|
593
|
-
|
|
806
|
+
prob = nested_logits.softmax(dim = -1)
|
|
807
|
+
|
|
808
|
+
log_probs = log(prob)
|
|
809
|
+
|
|
810
|
+
# maybe entropy
|
|
811
|
+
|
|
812
|
+
if return_entropies:
|
|
813
|
+
discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
|
|
814
|
+
discrete_entropies = cat(discrete_entropies.unbind())
|
|
815
|
+
discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
|
|
816
|
+
|
|
817
|
+
discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')
|
|
818
|
+
|
|
819
|
+
# back to regular tensor
|
|
820
|
+
|
|
821
|
+
log_probs = cat(log_probs.unbind())
|
|
822
|
+
log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
|
|
823
|
+
|
|
824
|
+
log_probs = inverse_pack_lead_dims(log_probs)
|
|
825
|
+
|
|
826
|
+
# get indices to gather
|
|
827
|
+
|
|
828
|
+
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
|
829
|
+
|
|
830
|
+
num_discrete_actions = self.num_discrete_actions[discrete_action_types]
|
|
831
|
+
|
|
832
|
+
offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
|
|
833
|
+
log_prob_indices = discrete_targets + offset
|
|
834
|
+
|
|
835
|
+
# gather
|
|
836
|
+
|
|
837
|
+
discrete_log_probs = log_probs.gather(-1, log_prob_indices)
|
|
838
|
+
|
|
839
|
+
else:
|
|
840
|
+
discrete_log_probs = []
|
|
841
|
+
discrete_entropies = []
|
|
842
|
+
|
|
843
|
+
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
|
|
844
|
+
|
|
845
|
+
one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
|
|
846
|
+
one_discrete_log_probs = log(one_discrete_probs)
|
|
847
|
+
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
|
|
848
|
+
|
|
849
|
+
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
|
|
850
|
+
discrete_log_probs.append(log_prob)
|
|
851
|
+
|
|
852
|
+
if return_entropies:
|
|
853
|
+
entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
|
|
854
|
+
discrete_entropies.append(entropy)
|
|
855
|
+
|
|
856
|
+
discrete_log_probs = cat(discrete_log_probs, dim = -1)
|
|
857
|
+
|
|
858
|
+
if return_entropies:
|
|
859
|
+
discrete_entropies = stack(discrete_entropies, dim = -1)
|
|
860
|
+
|
|
861
|
+
# continuous
|
|
862
|
+
|
|
863
|
+
continuous_log_probs = None
|
|
864
|
+
continuous_entropies = None
|
|
865
|
+
|
|
866
|
+
if exists(continuous_targets):
|
|
867
|
+
distr = mean_log_var_to_distr(continuous_action_mean_log_var)
|
|
868
|
+
continuous_log_probs = distr.log_prob(continuous_targets)
|
|
869
|
+
|
|
870
|
+
if return_entropies:
|
|
871
|
+
continuous_entropies = distr.entropy()
|
|
872
|
+
|
|
873
|
+
log_probs = (discrete_log_probs, continuous_log_probs)
|
|
874
|
+
|
|
875
|
+
if not return_entropies:
|
|
876
|
+
return log_probs
|
|
877
|
+
|
|
878
|
+
entropies = (discrete_entropies, continuous_entropies)
|
|
879
|
+
|
|
880
|
+
return log_probs, entropies
|
|
881
|
+
|
|
882
|
+
def kl_div(
|
|
594
883
|
self,
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
pre_rmsnorm = True,
|
|
600
|
-
):
|
|
601
|
-
super().__init__()
|
|
602
|
-
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
884
|
+
src: tuple[MaybeTensor, MaybeTensor],
|
|
885
|
+
tgt: tuple[MaybeTensor, MaybeTensor],
|
|
886
|
+
reduce_across_num_actions = True
|
|
887
|
+
) -> tuple[MaybeTensor, MaybeTensor]:
|
|
603
888
|
|
|
604
|
-
|
|
889
|
+
src_discrete, src_continuous = src
|
|
890
|
+
tgt_discrete, tgt_continuous = tgt
|
|
605
891
|
|
|
606
|
-
|
|
607
|
-
assert query_heads >= heads and divisible_by(query_heads, heads)
|
|
892
|
+
discrete_kl_div = None
|
|
608
893
|
|
|
609
|
-
#
|
|
894
|
+
# split discrete if it is not already (multiple discrete actions)
|
|
610
895
|
|
|
611
|
-
|
|
612
|
-
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
896
|
+
if exists(src_discrete):
|
|
613
897
|
|
|
614
|
-
|
|
615
|
-
dim_kv_inner = dim_head * heads
|
|
898
|
+
discrete_split = self.num_discrete_actions.tolist()
|
|
616
899
|
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
self.to_out = LinearNoBias(dim_q_inner, dim)
|
|
900
|
+
if is_tensor(src_discrete):
|
|
901
|
+
src_discrete = src_discrete.split(discrete_split, dim = -1)
|
|
620
902
|
|
|
621
|
-
|
|
903
|
+
if is_tensor(tgt_discrete):
|
|
904
|
+
tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
|
|
905
|
+
|
|
906
|
+
discrete_kl_divs = []
|
|
907
|
+
|
|
908
|
+
for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
|
|
909
|
+
|
|
910
|
+
src_log_probs = src_logit.log_softmax(dim = -1)
|
|
911
|
+
tgt_prob = tgt_logit.softmax(dim = -1)
|
|
912
|
+
|
|
913
|
+
one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
|
|
914
|
+
|
|
915
|
+
discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
|
|
916
|
+
|
|
917
|
+
discrete_kl_div = stack(discrete_kl_divs, dim = -1)
|
|
918
|
+
|
|
919
|
+
# calculate kl divergence for continuous
|
|
920
|
+
|
|
921
|
+
continuous_kl_div = None
|
|
922
|
+
|
|
923
|
+
if exists(src_continuous):
|
|
924
|
+
src_normal = mean_log_var_to_distr(src_continuous)
|
|
925
|
+
tgt_normal = mean_log_var_to_distr(tgt_continuous)
|
|
926
|
+
|
|
927
|
+
continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
|
|
928
|
+
|
|
929
|
+
# maybe reduce
|
|
930
|
+
|
|
931
|
+
if reduce_across_num_actions:
|
|
932
|
+
if exists(discrete_kl_div):
|
|
933
|
+
discrete_kl_div = discrete_kl_div.sum(dim = -1)
|
|
622
934
|
|
|
623
|
-
|
|
624
|
-
|
|
935
|
+
if exists(continuous_kl_div):
|
|
936
|
+
continuous_kl_div = continuous_kl_div.sum(dim = -1)
|
|
937
|
+
|
|
938
|
+
return discrete_kl_div, continuous_kl_div
|
|
625
939
|
|
|
626
940
|
def forward(
|
|
627
941
|
self,
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
942
|
+
*,
|
|
943
|
+
discrete_actions = None, # (... na)
|
|
944
|
+
continuous_actions = None, # (... na)
|
|
945
|
+
discrete_action_types = None, # (na)
|
|
946
|
+
continuous_action_types = None, # (na)
|
|
947
|
+
return_sum_pooled_embeds = True
|
|
633
948
|
):
|
|
634
|
-
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
|
635
949
|
|
|
636
|
-
|
|
950
|
+
discrete_embeds = continuous_embeds = None
|
|
637
951
|
|
|
638
|
-
|
|
952
|
+
if exists(discrete_actions):
|
|
639
953
|
|
|
640
|
-
|
|
954
|
+
discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
|
|
641
955
|
|
|
642
|
-
|
|
956
|
+
discrete_action_types = self.cast_action_types(discrete_action_types)
|
|
643
957
|
|
|
644
|
-
|
|
958
|
+
offsets = self.discrete_action_offsets[discrete_action_types]
|
|
645
959
|
|
|
646
|
-
|
|
647
|
-
k = self.k_heads_rmsnorm(k)
|
|
960
|
+
assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
|
|
648
961
|
|
|
649
|
-
|
|
962
|
+
# offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
|
|
650
963
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
k = cat((ck, k), dim = -2)
|
|
654
|
-
v = cat((cv, v), dim = -2)
|
|
964
|
+
discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
|
|
965
|
+
discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
|
|
655
966
|
|
|
656
|
-
|
|
967
|
+
if exists(continuous_actions):
|
|
968
|
+
continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
|
|
657
969
|
|
|
658
|
-
|
|
659
|
-
q = apply_rotations(rotary_pos_emb, q)
|
|
660
|
-
k = apply_rotations(rotary_pos_emb, k)
|
|
970
|
+
continuous_action_types = self.cast_action_types(continuous_action_types)
|
|
661
971
|
|
|
662
|
-
|
|
972
|
+
assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
|
|
663
973
|
|
|
664
|
-
|
|
974
|
+
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
|
|
665
975
|
|
|
666
|
-
|
|
976
|
+
# maybe normalization
|
|
667
977
|
|
|
668
|
-
|
|
978
|
+
if self.continuous_need_norm:
|
|
979
|
+
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
|
980
|
+
continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
|
|
669
981
|
|
|
670
|
-
|
|
982
|
+
# continuous embed is just the selected continuous action type with the scale
|
|
671
983
|
|
|
672
|
-
|
|
984
|
+
continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
|
|
673
985
|
|
|
674
|
-
|
|
986
|
+
# return not pooled
|
|
675
987
|
|
|
676
|
-
|
|
988
|
+
if not return_sum_pooled_embeds:
|
|
989
|
+
return ActionEmbeds(discrete_embeds, continuous_embeds)
|
|
677
990
|
|
|
678
|
-
|
|
679
|
-
return out
|
|
991
|
+
# handle sum pooling, which is what they did in the paper for all the actions
|
|
680
992
|
|
|
681
|
-
|
|
993
|
+
pooled = 0.
|
|
682
994
|
|
|
683
|
-
|
|
995
|
+
if exists(discrete_embeds):
|
|
996
|
+
pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
|
|
684
997
|
|
|
685
|
-
|
|
998
|
+
if exists(continuous_embeds):
|
|
999
|
+
pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
|
|
1000
|
+
|
|
1001
|
+
return pooled
|
|
1002
|
+
|
|
1003
|
+
# generalized advantage estimate
|
|
1004
|
+
|
|
1005
|
+
@torch.no_grad()
|
|
1006
|
+
def calc_gae(
|
|
1007
|
+
rewards,
|
|
1008
|
+
values,
|
|
1009
|
+
masks = None,
|
|
1010
|
+
gamma = 0.99,
|
|
1011
|
+
lam = 0.95,
|
|
1012
|
+
use_accelerated = None
|
|
1013
|
+
):
|
|
1014
|
+
assert values.shape[-1] == rewards.shape[-1]
|
|
1015
|
+
use_accelerated = default(use_accelerated, rewards.is_cuda)
|
|
1016
|
+
|
|
1017
|
+
if not exists(masks):
|
|
1018
|
+
masks = torch.ones_like(values)
|
|
1019
|
+
|
|
1020
|
+
values = F.pad(values, (0, 1), value = 0.)
|
|
1021
|
+
values, values_next = values[..., :-1], values[..., 1:]
|
|
1022
|
+
|
|
1023
|
+
delta = rewards + gamma * values_next * masks - values
|
|
1024
|
+
gates = gamma * lam * masks
|
|
1025
|
+
|
|
1026
|
+
scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
|
|
1027
|
+
|
|
1028
|
+
gae = scan(gates, delta)
|
|
1029
|
+
|
|
1030
|
+
returns = gae + values
|
|
1031
|
+
|
|
1032
|
+
return returns
|
|
1033
|
+
|
|
1034
|
+
# rotary embeddings for time
|
|
1035
|
+
|
|
1036
|
+
class Rotary1D(Module):
|
|
686
1037
|
def __init__(
|
|
687
1038
|
self,
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
pre_rmsnorm = True
|
|
1039
|
+
dim_head,
|
|
1040
|
+
theta = 10000.
|
|
691
1041
|
):
|
|
692
1042
|
super().__init__()
|
|
693
|
-
|
|
1043
|
+
inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
|
|
1044
|
+
self.register_buffer('inv_freq', inv_freq)
|
|
694
1045
|
|
|
695
|
-
|
|
1046
|
+
def forward(
|
|
1047
|
+
self,
|
|
1048
|
+
seq_len,
|
|
1049
|
+
offset = 0
|
|
1050
|
+
):
|
|
1051
|
+
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
|
696
1052
|
|
|
697
|
-
|
|
698
|
-
|
|
1053
|
+
t = torch.arange(seq_len, device = device).type(dtype) + offset
|
|
1054
|
+
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
|
699
1055
|
|
|
700
|
-
|
|
701
|
-
x = self.norm(x)
|
|
1056
|
+
return cat((freqs, freqs), dim = -1)
|
|
702
1057
|
|
|
703
|
-
x, gates = self.proj_in(x).chunk(2, dim = -1)
|
|
704
|
-
x = x * F.gelu(gates)
|
|
705
1058
|
|
|
706
|
-
|
|
1059
|
+
def apply_rotations(
|
|
1060
|
+
rotations, # (h n d) | (n d)
|
|
1061
|
+
t # (b h n d)
|
|
1062
|
+
):
|
|
707
1063
|
|
|
708
|
-
|
|
1064
|
+
heads, seq_len, dtype = *t.shape[1:3], t.dtype
|
|
709
1065
|
|
|
710
|
-
|
|
1066
|
+
rotations_seq_len = rotations.shape[-2]
|
|
1067
|
+
|
|
1068
|
+
# handle kv caching with rotations
|
|
1069
|
+
|
|
1070
|
+
if rotations_seq_len > seq_len:
|
|
1071
|
+
rotations = rotations[-seq_len:]
|
|
1072
|
+
|
|
1073
|
+
# precision
|
|
1074
|
+
|
|
1075
|
+
t = t.float()
|
|
1076
|
+
|
|
1077
|
+
# handle gqa for rotary
|
|
1078
|
+
|
|
1079
|
+
if rotations.ndim == 3 and rotations.shape[0] < heads:
|
|
1080
|
+
rotary_heads = rotations.shape[0]
|
|
1081
|
+
|
|
1082
|
+
assert divisible_by(heads, rotary_heads)
|
|
1083
|
+
groups = heads // rotary_heads
|
|
1084
|
+
rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
|
|
1085
|
+
|
|
1086
|
+
x1, x2 = t.chunk(2, dim = -1)
|
|
1087
|
+
rotated_half_t = cat((-x2, x1), dim = -1)
|
|
1088
|
+
|
|
1089
|
+
# rotate in the positions
|
|
1090
|
+
|
|
1091
|
+
rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
|
|
1092
|
+
return rotated.type(dtype)
|
|
1093
|
+
|
|
1094
|
+
# multi-head rmsnorm
|
|
1095
|
+
|
|
1096
|
+
class MultiHeadRMSNorm(Module):
|
|
711
1097
|
def __init__(
|
|
712
1098
|
self,
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
1099
|
+
dim_head,
|
|
1100
|
+
heads = 8
|
|
1101
|
+
):
|
|
1102
|
+
super().__init__()
|
|
1103
|
+
self.scale = dim_head ** 0.5
|
|
1104
|
+
self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
|
|
1105
|
+
|
|
1106
|
+
def forward(
|
|
1107
|
+
self,
|
|
1108
|
+
x # (b h n d)
|
|
1109
|
+
):
|
|
1110
|
+
normed = l2norm(x)
|
|
1111
|
+
scale = (self.gamma + 1.) * self.scale
|
|
1112
|
+
return multiply('... h n d, h d', normed, scale)
|
|
1113
|
+
|
|
1114
|
+
# naive attend
|
|
1115
|
+
|
|
1116
|
+
def naive_attend(
|
|
1117
|
+
q, k, v,
|
|
1118
|
+
softclamp_value = None,
|
|
1119
|
+
scale = None,
|
|
1120
|
+
causal = False,
|
|
1121
|
+
causal_block_size = 1,
|
|
1122
|
+
mask = None
|
|
1123
|
+
):
|
|
1124
|
+
|
|
1125
|
+
if not exists(scale):
|
|
1126
|
+
scale = q.shape[-1] ** -0.5
|
|
1127
|
+
|
|
1128
|
+
# grouped query attention
|
|
1129
|
+
|
|
1130
|
+
groups = q.shape[1] // k.shape[1]
|
|
1131
|
+
|
|
1132
|
+
q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
|
|
1133
|
+
|
|
1134
|
+
# similarity
|
|
1135
|
+
|
|
1136
|
+
sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
|
|
1137
|
+
|
|
1138
|
+
# scale and attention
|
|
1139
|
+
|
|
1140
|
+
sim = sim * scale
|
|
1141
|
+
|
|
1142
|
+
# softclamping a la gemma 3
|
|
1143
|
+
|
|
1144
|
+
if exists(softclamp_value):
|
|
1145
|
+
sim = softclamp(sim, softclamp_value)
|
|
1146
|
+
|
|
1147
|
+
# masking
|
|
1148
|
+
|
|
1149
|
+
mask_value = -torch.finfo(sim.dtype).max
|
|
1150
|
+
|
|
1151
|
+
if exists(mask):
|
|
1152
|
+
sim = sim.masked_fill(~mask, mask_value)
|
|
1153
|
+
|
|
1154
|
+
if causal:
|
|
1155
|
+
is_blocked_causal = causal_block_size > 1
|
|
1156
|
+
i, j = sim.shape[-2:]
|
|
1157
|
+
|
|
1158
|
+
if is_blocked_causal:
|
|
1159
|
+
i = ceil(i / causal_block_size)
|
|
1160
|
+
j = ceil(j / causal_block_size)
|
|
1161
|
+
|
|
1162
|
+
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
|
1163
|
+
|
|
1164
|
+
if causal_block_size > 1:
|
|
1165
|
+
causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
|
|
1166
|
+
causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
|
|
1167
|
+
|
|
1168
|
+
sim = sim.masked_fill(causal_mask, mask_value)
|
|
1169
|
+
|
|
1170
|
+
# attend
|
|
1171
|
+
|
|
1172
|
+
attn = sim.softmax(dim = -1)
|
|
1173
|
+
|
|
1174
|
+
# aggregate
|
|
1175
|
+
|
|
1176
|
+
out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
|
|
1177
|
+
|
|
1178
|
+
# merge the groups
|
|
1179
|
+
|
|
1180
|
+
return rearrange(out, 'b h g i d -> b (h g) i d')
|
|
1181
|
+
|
|
1182
|
+
# flex attention related and factory function for attend depending on whether on cuda + flex attention available
|
|
1183
|
+
|
|
1184
|
+
def block_mask_causal(block_size):
|
|
1185
|
+
|
|
1186
|
+
def inner(b, h, q, k):
|
|
1187
|
+
bq = q // block_size
|
|
1188
|
+
bk = k // block_size
|
|
1189
|
+
return bq >= bk
|
|
1190
|
+
|
|
1191
|
+
return inner
|
|
1192
|
+
|
|
1193
|
+
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
|
|
1194
|
+
bq = q % seq_len
|
|
1195
|
+
bk = k % seq_len
|
|
1196
|
+
|
|
1197
|
+
is_special_start_index = seq_len - num_tokens
|
|
1198
|
+
|
|
1199
|
+
q_is_special = q >= is_special_start_index
|
|
1200
|
+
k_is_special = k >= is_special_start_index
|
|
1201
|
+
|
|
1202
|
+
if special_attend_only_itself:
|
|
1203
|
+
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
|
|
1204
|
+
else:
|
|
1205
|
+
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
|
1206
|
+
|
|
1207
|
+
return out
|
|
1208
|
+
|
|
1209
|
+
def block_mask_special_tokens_right(
|
|
1210
|
+
seq_len,
|
|
1211
|
+
num_tokens,
|
|
1212
|
+
special_attend_only_itself = False
|
|
1213
|
+
):
|
|
1214
|
+
def inner(b, h, q, k):
|
|
1215
|
+
return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
|
|
1216
|
+
return inner
|
|
1217
|
+
|
|
1218
|
+
def compose_mask(mask1, mask2):
|
|
1219
|
+
def inner(b, h, q, k):
|
|
1220
|
+
return mask1(b, h, q, k) & mask2(b, h, q, k)
|
|
1221
|
+
|
|
1222
|
+
return inner
|
|
1223
|
+
|
|
1224
|
+
def block_mask_noop(b, h, q, k):
|
|
1225
|
+
return b >= 0
|
|
1226
|
+
|
|
1227
|
+
def score_mod_softclamp(value):
|
|
1228
|
+
def inner(sim, b, h, q, k):
|
|
1229
|
+
if not exists(value):
|
|
1230
|
+
return sim
|
|
1231
|
+
|
|
1232
|
+
sim = sim / value
|
|
1233
|
+
sim = torch.tanh(sim)
|
|
1234
|
+
sim = sim * value
|
|
1235
|
+
return sim
|
|
1236
|
+
|
|
1237
|
+
return inner
|
|
1238
|
+
|
|
1239
|
+
# factory for attend function
|
|
1240
|
+
|
|
1241
|
+
def get_attend_fn(
|
|
1242
|
+
use_flex,
|
|
1243
|
+
seq_len,
|
|
1244
|
+
k_seq_len,
|
|
1245
|
+
causal = False,
|
|
1246
|
+
causal_block_size = 1,
|
|
1247
|
+
softclamp_value = 50.,
|
|
1248
|
+
num_special_tokens = 0, # special tokens are latents / agents
|
|
1249
|
+
block_size_per_special = None, # defaults to k_seq_len
|
|
1250
|
+
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
|
|
1251
|
+
device = None
|
|
1252
|
+
):
|
|
1253
|
+
block_size_per_special = default(block_size_per_special, k_seq_len)
|
|
1254
|
+
|
|
1255
|
+
if use_flex:
|
|
1256
|
+
# flex pathway
|
|
1257
|
+
|
|
1258
|
+
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
|
|
1259
|
+
|
|
1260
|
+
if num_special_tokens > 0:
|
|
1261
|
+
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
1262
|
+
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
|
|
1263
|
+
|
|
1264
|
+
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
|
|
1265
|
+
|
|
1266
|
+
score_mod = score_mod_softclamp(softclamp_value)
|
|
1267
|
+
attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
|
|
1268
|
+
else:
|
|
1269
|
+
# naive pathway
|
|
1270
|
+
|
|
1271
|
+
mask = None
|
|
1272
|
+
if num_special_tokens > 0:
|
|
1273
|
+
q_seq = torch.arange(seq_len, device = device)[:, None]
|
|
1274
|
+
k_seq = torch.arange(k_seq_len, device = device)[None, :]
|
|
1275
|
+
|
|
1276
|
+
mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
|
|
1277
|
+
|
|
1278
|
+
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
|
|
1279
|
+
|
|
1280
|
+
return attend_fn
|
|
1281
|
+
|
|
1282
|
+
# attention
|
|
1283
|
+
|
|
1284
|
+
class Attention(Module):
|
|
1285
|
+
def __init__(
|
|
1286
|
+
self,
|
|
1287
|
+
dim,
|
|
1288
|
+
dim_head = 64,
|
|
1289
|
+
query_heads = None,
|
|
1290
|
+
heads = 8,
|
|
1291
|
+
pre_rmsnorm = True,
|
|
1292
|
+
gate_values = True,
|
|
1293
|
+
rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
|
|
1294
|
+
rmsnorm_key = True,
|
|
1295
|
+
value_residual = True
|
|
1296
|
+
):
|
|
1297
|
+
super().__init__()
|
|
1298
|
+
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
|
1299
|
+
|
|
1300
|
+
# setup grouped query attention
|
|
1301
|
+
|
|
1302
|
+
query_heads = default(query_heads, heads)
|
|
1303
|
+
assert query_heads >= heads and divisible_by(query_heads, heads)
|
|
1304
|
+
|
|
1305
|
+
# scaling, splitting and merging of heads
|
|
1306
|
+
|
|
1307
|
+
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
|
|
1308
|
+
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
1309
|
+
|
|
1310
|
+
dim_q_inner = dim_head * query_heads
|
|
1311
|
+
dim_kv_inner = dim_head * heads
|
|
1312
|
+
|
|
1313
|
+
self.to_q = LinearNoBias(dim, dim_q_inner)
|
|
1314
|
+
self.to_k = LinearNoBias(dim, dim_kv_inner)
|
|
1315
|
+
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
|
1316
|
+
self.to_out = LinearNoBias(dim_q_inner, dim)
|
|
1317
|
+
|
|
1318
|
+
# alphafold gating per head, for attending to nothing
|
|
1319
|
+
|
|
1320
|
+
self.to_gates = None
|
|
1321
|
+
|
|
1322
|
+
if gate_values:
|
|
1323
|
+
self.to_gates = Sequential(
|
|
1324
|
+
LinearNoBias(dim, query_heads),
|
|
1325
|
+
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        )
+
+        # stability related
+
+        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
+        self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
+
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
+    def muon_parameters(self):
+        # omit the queries and keys for now given what we learned from kimi 2 paper
+
+        return [
+            *self.to_v.parameters(),
+            *self.to_out.parameters(),
+        ]
+
+    def forward(
+        self,
+        tokens,                            # (b n d)
+        kv_cache = None,
+        return_intermediates = False,
+        rotary_pos_emb = None,
+        residual_values = None,            # (b n h d)
+        attend_fn: Callable | None = None
+    ):
+        tokens, inverse_packed_batch = pack_one(tokens, '* n d')
+
+        tokens = self.norm(tokens)
+
+        q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
+
+        # split heads
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
+        # qk rmsnorm
+
+        q = self.q_heads_rmsnorm(q)
+        k = self.k_heads_rmsnorm(k)
+
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
+        # caching
+
+        if exists(kv_cache):
+            ck, cv = kv_cache
+            k = cat((ck, k), dim = -2)
+            v = cat((cv, v), dim = -2)
+
+        # attention
+
+        attend_fn = default(attend_fn, naive_attend)
+
+        out = attend_fn(q, k, v)
+
+        # gate values
+
+        if exists(self.to_gates):
+            gates = self.to_gates(tokens)
+            out = out * gates
+
+        # merge heads
+
+        out = self.merge_heads(out)
+
+        # combine heads
+
+        out = self.to_out(out)
+
+        out = inverse_packed_batch(out)
+
+        if not return_intermediates:
+            return out
+
+        return out, AttentionIntermediates(stack((k, v)), tokens)
+
+# feedforward
+
+class SwiGLUFeedforward(Module):
+    def __init__(
+        self,
+        dim,
+        expansion_factor = 4,
+        pre_rmsnorm = True
+    ):
+        super().__init__()
+        self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
+
+        dim_inner = int(dim * expansion_factor * 2 / 3)
+
+        self.proj_in = Linear(dim, dim_inner * 2)
+        self.proj_out = Linear(dim_inner, dim)
+
+    def muon_parameters(self):
+        return [
+            self.proj_in.weight,
+            self.proj_out.weight,
+        ]
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        x, gates = self.proj_in(x).chunk(2, dim = -1)
+        x = x * F.gelu(gates)
+
+        return self.proj_out(x)
+
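The learned value-residual mix added above boils down to a per-head, per-token sigmoid gate that interpolates each layer's attention values toward the values computed at the first layer. A minimal standalone sketch of just that mixing step, with illustrative shapes and names (none of these are the package's API):

```python
import torch
from torch import nn
from einops import rearrange

batch, seq, dim, heads, dim_head = 2, 16, 64, 4, 16

# per-head mix in [0, 1], predicted from the tokens
to_mix = nn.Sequential(nn.Linear(dim, heads), nn.Sigmoid())

tokens = torch.randn(batch, seq, dim)
values = torch.randn(batch, heads, seq, dim_head)        # this layer's values
first_values = torch.randn(batch, heads, seq, dim_head)  # values cached from the first layer

mix = rearrange(to_mix(tokens), 'b n h -> b h n 1')
values = values.lerp(first_values, mix)                  # v <- (1 - mix) * v + mix * v_first
```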
+# axial space time transformer
+
+class AxialSpaceTimeTransformer(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        attn_heads = 8,
+        attn_dim_head = 64,
+        attn_softclamp_value = 50.,
+        time_block_every = 4,
+        attn_kwargs: dict = dict(),
+        ff_kwargs: dict = dict(),
+        num_residual_streams = 1,
+        num_special_spatial_tokens = 1,
+        special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (the latents can only attend to themselves, while the spatial modalities attend to the latents and everything else)
+        final_norm = True,
+        value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        rnn_time = False
+    ):
+        super().__init__()
+        assert depth >= time_block_every, f'depth must be at least {time_block_every}'
+
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
+        # attention
+
+        self.attn_softclamp_value = attn_softclamp_value
+
+        # attention masking
+
+        self.special_attend_only_itself = special_attend_only_itself
+
+        # time rotary embedding
+
+        self.time_rotary = Rotary1D(attn_dim_head)
+
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
+        # a gru layer across time
+
+        self.rnn_time = rnn_time
+        rnn_layers = []
+
+        # transformer
+
+        layers = []
+        is_time = []
+
+        for i in range(depth):
+            layer_index = i + 1
+
+            is_time_block = divisible_by(layer_index, time_block_every)
+            is_time.append(is_time_block)
+
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
+
+            layers.append(ModuleList([
+                rearrange_to_attend,
+                rearrange_from_attend,
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
+                hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
+            ]))
+
+            rnn_layers.append(ModuleList([
+                nn.RMSNorm(dim),
+                nn.GRU(dim, dim, batch_first = True)
+            ]) if is_time_block and rnn_time else None)
+
+        self.layers = ModuleList(layers)
+        self.rnn_layers = ModuleList(rnn_layers)
+
+        self.is_time = is_time
+
+        # final norm
+
+        self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()
+
+        # special tokens
+
+        self.num_special_spatial_tokens = num_special_spatial_tokens
+
+    def muon_parameters(self):
+        muon_params = []
+
+        for m in self.modules():
+            if isinstance(m, (Attention, SwiGLUFeedforward)):
+                muon_params.extend(m.muon_parameters())
+
+        return muon_params
+
+    def forward(
+        self,
+        tokens,                           # (b t s d)
+        kv_cache: Tensor | None = None,   # (y 2 b h t d)
+        return_intermediates = False
+
+    ): # (b t s d) | (y 2 b h t d)
+
+        batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
+
+        assert tokens.ndim == 4
+
+        # attend functions for space and time
+
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
+
+        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
+
+        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by the modality tokens
+
+        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
+
+        # prepare cache
+
+        time_attn_kv_caches = []
+
+        if has_kv_cache:
+            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+
+            rotary_seq_len = 1
+            rotary_pos_offset = past_tokens.shape[1]
+        else:
+            rotary_seq_len = time
+            rotary_pos_offset = 0
+
+        kv_cache = default(kv_cache, (None,))
+
+        iter_kv_cache = iter(kv_cache)
+
+        # rotary
+
+        rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
+
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
+        # normed attention inputs
+
+        normed_time_attn_inputs = []
+        normed_space_attn_inputs = []
+
+        # attention
+
+        tokens = self.expand_streams(tokens)
+
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
+
+            tokens = pre_attn_rearrange(tokens)
+
+            # maybe rnn for time
+
+            if layer_is_time and exists(maybe_rnn_modules):
+                rnn_prenorm, rnn = maybe_rnn_modules
+
+                rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
+
+                rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
+
+                rnn_out = inverse_pack_time(rnn_out)
+
+                tokens = rnn_out + tokens
+
+            # when it is an axial time attention block, attention should be causal
+
+            attend_fn = time_attend if layer_is_time else space_attend
+
+            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+
+            # maybe past kv cache
+
+            maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
+
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
+            # attention layer
+
+            tokens, attn_intermediates = attn(
+                tokens,
+                rotary_pos_emb = layer_rotary_pos_emb,
+                attend_fn = attend_fn,
+                kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
+                return_intermediates = True
+            )
+
+            tokens = post_attn_rearrange(tokens)
+
+            # feedforward layer
+
+            tokens = ff(tokens)
+
+            # save kv cache if is time layer
+
+            if layer_is_time:
+                time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+            # save normed attention inputs (space or time) for decorr
+
+            space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+            space_or_time_inputs.append(attn_intermediates.normed_inputs)
+
+        tokens = self.reduce_streams(tokens)
+
+        out = self.final_norm(tokens)
+
+        if has_kv_cache:
+            # just concat the past tokens back on for now, todo - clean up the logic
+            out = cat((past_tokens, out), dim = 1)
+
+        if not return_intermediates:
+            return out
+
+        intermediates = TransformerIntermediates(
+            stack(time_attn_kv_caches),
+            safe_stack(normed_time_attn_inputs),
+            safe_stack(normed_space_attn_inputs)
+        )
+
+        return out, intermediates
+
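The transformer above attends along exactly one axis per block: space blocks fold time into the batch and attend bidirectionally, while every `time_block_every`-th block transposes the token grid to fold space into the batch and attends causally over time. A minimal illustration of that folding, independent of the class (the `attend` stand-in is illustrative):

```python
import torch
from einops import rearrange

b, t, s, d = 2, 8, 10, 32
tokens = torch.randn(b, t, s, d)

def attend(x):
    # stand-in for a real attention block over the second-to-last axis
    return x

# space block: fold time into batch, attend across the s axis (bidirectional)
space_in = rearrange(tokens, 'b t s d -> (b t) s d')
space_out = rearrange(attend(space_in), '(b t) s d -> b t s d', t = t)

# time block: fold space into batch, attend across the t axis (causal)
time_in = rearrange(tokens, 'b t s d -> (b s) t d')
time_out = rearrange(attend(time_in), '(b s) t d -> b t s d', s = s)
```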
+# video tokenizer
+
+class VideoTokenizer(Module):
+    def __init__(
+        self,
+        dim,
+        dim_latent,
+        patch_size,
+        image_height = None,
+        image_width = None,
+        num_latent_tokens = 4,
+        encoder_depth = 4,
+        decoder_depth = 4,
+        time_block_every = 4,
+        attn_kwargs: dict = dict(),
+        attn_dim_head = 64,
+        attn_heads = 8,
+        attn_softclamp_value = 50.,
+        ff_kwargs: dict = dict(),
+        decoder_pos_mlp_depth = 2,
+        channels = 3,
+        per_image_patch_mask_prob = (0., 0.9), # patch masking appears to use a per-image probability drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
+        lpips_loss_network: Module | None = None,
+        lpips_loss_weight = 0.2,
+        encoder_add_decor_aux_loss = False,
+        decorr_aux_loss_weight = 0.1,
+        decorr_sample_frac = 0.25,
+        nd_rotary_kwargs: dict = dict(
+            rope_min_freq = 1.,
+            rope_max_freq = 10000.,
+            rope_p_zero_freqs = 0.
+        ),
+        num_residual_streams = 1,
+    ):
+        super().__init__()
+
+        self.patch_size = patch_size
+
+        # special tokens
+
+        assert num_latent_tokens >= 1
+        self.num_latent_tokens = num_latent_tokens
+        self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
+
+        # hyper connections
+
+        hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
+
+        # mae masking - Kaiming He paper from long ago
+
+        self.per_image_patch_mask_prob = per_image_patch_mask_prob
+        self.mask_token = Parameter(randn(dim) * 1e-2)
+
+        # patch and unpatch
+
+        dim_patch = channels * patch_size ** 2
+
+        self.patch_to_tokens = Sequential(
+            Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            Linear(dim_patch, dim)
+        )
+
+        self.tokens_to_patch = Sequential(
+            Linear(dim, dim_patch),
+            Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
+        )
+
+        # encoder space / time transformer
+
+        self.encoder_transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = encoder_depth,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            time_block_every = time_block_every,
+            num_special_spatial_tokens = num_latent_tokens,
+            num_residual_streams = num_residual_streams,
+            final_norm = True
+        )
+
+        # latents
+
+        self.encoded_to_latents = Sequential(
+            LinearNoBias(dim, dim_latent),
+            nn.Tanh(),
+        )
+
+        self.latents_to_decoder = LinearNoBias(dim_latent, dim)
+
+        # decoder
+
+        self.image_height = image_height
+        self.image_width = image_width
+
+        # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
+
+        self.to_decoder_pos_emb = create_mlp(
+            dim_in = 2,
+            dim = dim * 2,
+            dim_out = dim,
+            depth = decoder_pos_mlp_depth,
+        )
+
+        # decoder transformer
+
+        self.decoder_transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = decoder_depth,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            time_block_every = time_block_every,
+            num_special_spatial_tokens = num_latent_tokens,
+            num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
+            final_norm = True
+        )
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+        self.has_lpips_loss = lpips_loss_weight > 0.
+        self.lpips_loss_weight = lpips_loss_weight
+
+        if self.has_lpips_loss:
+            self.lpips = LPIPSLoss(lpips_loss_network)
+
+        # decorr aux loss
+        # https://arxiv.org/abs/2510.14657
+
+        self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+        self.decorr_aux_loss_weight = decorr_aux_loss_weight
+
+        self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
+    @property
+    def device(self):
+        return self.zero.device
+
+    def muon_parameters(self):
+        return [
+            *self.encoder_transformer.muon_parameters(),
+            *self.decoder_transformer.muon_parameters()
+        ]
+
+    @torch.no_grad()
+    def tokenize(
+        self,
+        video
+    ):
+        self.eval()
+        return self.forward(video, return_latents = True)
+
+    def decode(
+        self,
+        latents,     # (b t n d)
+        height = None,
+        width = None,
+    ): # (b c t h w)
+
+        height = default(height, self.image_height)
+        width = default(width, self.image_width)
+
+        assert exists(height) and exists(width), 'image height and width need to be passed in when decoding latents'
+
+        batch, time, device = *latents.shape[:2], latents.device
+
+        use_flex = latents.is_cuda and exists(flex_attention)
+
+        num_patch_height = height // self.patch_size
+        num_patch_width = width // self.patch_size
+
+        # latents to tokens
+
+        latent_tokens = self.latents_to_decoder(latents)
+
+        # generate decoder positional embedding and concat the latent token
+
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
+
+        # decoder attention
+
+        tokens = self.decoder_transformer(tokens)
+
+        # unpack latents
+
+        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        # project back to patches
+
+        recon_video = self.tokens_to_patch(tokens)
+
+        return recon_video
+
+    def forward(
+        self,
+        video,       # (b c t h w)
+        return_latents = False,
+        mask_patches = None,
+        return_all_losses = False
+    ):
+        batch, _, time, height, width = video.shape
+        patch_size, device = self.patch_size, video.device
+
+        assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
+
+        # to tokens
+
+        tokens = self.patch_to_tokens(video)
+
+        # get some dimensions
+
+        num_patch_height, num_patch_width, _ = tokens.shape[-3:]
+
+        # masking
+
+        mask_patches = default(mask_patches, self.training)
+
+        if mask_patches:
+            min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
+
+            mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t)
+
+            mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
+            mask_patch = torch.bernoulli(mask_prob) == 1.
+
+            tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
+
+        # pack space
+
+        tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
+
+        # add the latent
+
+        latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
+
+        tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
+
+        # encoder attention
+
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
+
+        # latent bottleneck
+
+        tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        latents = self.encoded_to_latents(latents)
+
+        if return_latents:
+            return latents
+
+        recon_video = self.decode(latents, height = height, width = width)
+
+        # losses
+
+        recon_loss = F.mse_loss(video, recon_video)
+
+        lpips_loss = self.zero
+
+        if self.has_lpips_loss:
+            lpips_loss = self.lpips(video, recon_video)
+
+        time_decorr_loss = space_decorr_loss = self.zero
+
+        if self.encoder_add_decor_aux_loss:
+            if exists(time_attn_normed_inputs):
+                time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+            if exists(space_attn_normed_inputs):
+                space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
+        # total loss
+
+        total_loss = (
+            recon_loss +
+            lpips_loss * self.lpips_loss_weight +
+            time_decorr_loss * self.decorr_aux_loss_weight +
+            space_decorr_loss * self.decorr_aux_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
+
+        return total_loss, TokenizerLosses(*losses)
+
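The masking scheme in the tokenizer's forward pass draws one masking probability per image uniformly from the configured range, Bernoulli-samples a patch mask at that rate, then swaps masked patches for a learned mask token. A standalone sketch of just that step, with made-up shapes (plain `torch.where` here stands in for the `einx.where` call above):

```python
import torch
from einops import repeat

b, t, vh, vw, dim = 2, 4, 6, 6, 32
tokens = torch.randn(b, t, vh, vw, dim)
mask_token = torch.randn(dim)

min_p, max_p = 0., 0.9
mask_prob = torch.empty(b, t).uniform_(min_p, max_p)              # one prob per image
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = vh, vw = vw)
mask = torch.bernoulli(mask_prob).bool()

# broadcast the (dim,) mask token over every masked patch
tokens = torch.where(mask[..., None], mask_token, tokens)
```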
+# dynamics model, axial space-time transformer
+
+class DynamicsWorldModel(Module):
+    def __init__(
+        self,
+        dim,
+        dim_latent,
+        video_tokenizer: VideoTokenizer | None = None,
+        max_steps = 64,           # K_max in paper
+        num_register_tokens = 8,  # they claim register tokens led to better temporal consistency
+        num_spatial_tokens = 2,   # latents projected to greater number of spatial tokens
+        num_latent_tokens = None,
+        num_agents = 1,
+        num_tasks = 0,
+        num_video_views = 1,
+        dim_proprio = None,
+        reward_encoder_kwargs: dict = dict(),
+        depth = 4,
+        pred_orig_latent = True,  # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
+        time_block_every = 4,     # every 4th block is time
+        attn_kwargs: dict = dict(),
+        transformer_kwargs: dict = dict(),
+        attn_heads = 8,
+        attn_dim_head = 64,
+        attn_softclamp_value = 50.,
+        ff_kwargs: dict = dict(),
+        loss_weight_fn: Callable = ramp_weight,
+        num_future_predictions = 8,     # they do multi-token prediction of 8 steps forward
+        prob_no_shortcut_train = None,  # probability of no shortcut training, defaults to 1 / num_step_sizes
+        add_reward_embed_to_agent_token = False,
+        add_reward_embed_dropout = 0.1,
+        num_discrete_actions: int | tuple[int, ...] = 0,
+        num_continuous_actions = 0,
+        continuous_norm_stats = None,
+        multi_token_pred_len = 8,
+        value_head_mlp_depth = 3,
+        policy_head_mlp_depth = 3,
+        latent_flow_loss_weight = 1.,
+        reward_loss_weight: float | list[float] = 1.,
+        discrete_action_loss_weight: float | list[float] = 1.,
+        continuous_action_loss_weight: float | list[float] = 1.,
+        num_latent_genes = 0,     # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
+        num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
+        gae_discount_factor = 0.997,
+        gae_lambda = 0.95,
+        ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5,   # pos and neg equal weight
+        pmpo_reverse_kl = True,
+        pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
+        value_clip = 0.4,
+        policy_entropy_weight = .01,
+        gae_use_accelerated = False
+    ):
+        super().__init__()
+
+        # can accept raw video if tokenizer is passed in
+
+        self.video_tokenizer = video_tokenizer
+
+        if exists(video_tokenizer):
+            num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
+            assert video_tokenizer.num_latent_tokens == num_latent_tokens, '`num_latent_tokens` must be the same for the tokenizer and dynamics model'
+
+        assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
+
+        # spatial
+
+        self.num_latent_tokens = num_latent_tokens
+        self.dim_latent = dim_latent
+        self.latent_shape = (num_latent_tokens, dim_latent)
+
+        if num_spatial_tokens >= num_latent_tokens:
+            assert divisible_by(num_spatial_tokens, num_latent_tokens)
+
+            expand_factor = num_spatial_tokens // num_latent_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Linear(dim_latent, dim * expand_factor),
+                Rearrange('... (s d) -> ... s d', s = expand_factor)
+            )
+
+            self.to_latent_pred = Sequential(
+                Reduce('b t v n s d -> b t v n d', 'mean'),
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent)
+            )
+
+        else:
+            assert divisible_by(num_latent_tokens, num_spatial_tokens)
+            latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Rearrange('... n d -> ... (n d)'),
+                Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
+                Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
+            )
+
+            self.to_latent_pred = Sequential(
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent * latent_tokens_to_space),
+                Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
+            )
+
+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
+        # proprioception
+
+        self.has_proprio = exists(dim_proprio)
+        self.dim_proprio = dim_proprio
+
+        if self.has_proprio:
+            self.to_proprio_token = nn.Linear(dim_proprio, dim)
+
+            self.to_proprio_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, dim_proprio)
+            )
+
+        # register tokens
+
+        self.num_register_tokens = num_register_tokens
+        self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+
+        # signal and step sizes
+
+        assert divisible_by(dim, 2)
+        dim_half = dim // 2
+
+        assert is_power_two(max_steps), '`max_steps` must be a power of 2'
+        self.max_steps = max_steps
+        self.num_step_sizes_log2 = int(log2(max_steps))
+
+        self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
+        self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
+
+        self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
+
+        # loss related
+
+        self.pred_orig_latent = pred_orig_latent # x-space or v-space
+        self.loss_weight_fn = loss_weight_fn
+
+        # reinforcement related
+
+        # they sum all the actions into a single token
+
+        self.num_agents = num_agents
+
+        self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+        self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+
+        self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+
+        self.num_tasks = num_tasks
+        self.task_embed = nn.Embedding(num_tasks, dim)
+
+        # learned set of latent genes
+
+        self.agent_has_genes = num_latent_genes > 0
+        self.num_latent_genes = num_latent_genes
+        self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
+
+        # policy head
+
+        self.policy_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = dim * 4,
+            depth = policy_head_mlp_depth
+        )
+
+        # action embedder
+
+        self.action_embedder = ActionEmbedder(
+            dim = dim,
+            num_discrete_actions = num_discrete_actions,
+            num_continuous_actions = num_continuous_actions,
+            continuous_norm_stats = continuous_norm_stats,
+            can_unembed = True,
+            unembed_dim = dim * 4,
+            num_unembed_preds = multi_token_pred_len,
+            squeeze_unembed_preds = False
+        )
+
+        # multi token prediction length
+
+        self.multi_token_pred_len = multi_token_pred_len
+
+        # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
+
+        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
+        self.add_reward_embed_dropout = add_reward_embed_dropout
+
+        self.reward_encoder = SymExpTwoHot(
+            **reward_encoder_kwargs,
+            dim_embed = dim,
+            learned_embedding = add_reward_embed_to_agent_token
+        )
+
+        to_reward_pred = Sequential(
+            RMSNorm(dim),
+            LinearNoBias(dim, self.reward_encoder.num_bins)
+        )
+
+        self.to_reward_pred = Ensemble(
+            to_reward_pred,
+            multi_token_pred_len
+        )
+
+        # value head
+
+        self.value_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = self.reward_encoder.num_bins,
+            depth = value_head_mlp_depth,
         )
-    ):
-        super().__init__()

-
+        # efficient axial space / time transformer
+
+        self.transformer = AxialSpaceTimeTransformer(
+            dim = dim,
+            depth = depth,
+            attn_heads = attn_heads,
+            attn_dim_head = attn_dim_head,
+            attn_softclamp_value = attn_softclamp_value,
+            attn_kwargs = attn_kwargs,
+            ff_kwargs = ff_kwargs,
+            num_residual_streams = num_residual_streams,
+            num_special_spatial_tokens = num_agents,
+            time_block_every = time_block_every,
+            final_norm = False,
+            **transformer_kwargs
+        )

-        #
+        # ppo related

-
-        self.
-        self.
+        self.gae_use_accelerated = gae_use_accelerated
+        self.gae_discount_factor = gae_discount_factor
+        self.gae_lambda = gae_lambda

-
+        self.ppo_eps_clip = ppo_eps_clip
+        self.value_clip = value_clip
+        self.policy_entropy_weight = policy_entropy_weight

-
-        self.mask_token = Parameter(randn(dim) * 1e-2)
+        # pmpo related

-
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+        self.pmpo_reverse_kl = pmpo_reverse_kl

-
+        # rewards related

-        self.
-
-            Linear(dim_patch, dim)
-        )
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay

-        self.
-            Linear(dim, dim_patch),
-            Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
-        )
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)

-
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))

-
-            dim_pos = 3,
-            heads = attn_heads,
-            dim_head = attn_dim_head,
-            **nd_rotary_kwargs
-        )
+        # loss related

-
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

-        self.
+        self.latent_flow_loss_weight = latent_flow_loss_weight

-
+        self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
+        self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
+        self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))

-
+        assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}

-
-        encoder_layers.append(ModuleList([
-            Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
-            SwiGLUFeedforward(dim = dim, **ff_kwargs)
-        ]))
+        self.register_buffer('zero', tensor(0.), persistent = False)

-
-
+    @property
+    def device(self):
+        return self.zero.device

-
+    # types of parameters

-
-
-            nn.Tanh(),
-        )
+    def muon_parameters(self):
+        return self.transformer.muon_parameters()

-
+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]

-
+    def value_head_parameters(self):
+        return self.value_head.parameters()

-
-
+    def parameters(self):
+        params = super().parameters()

-
+        if not exists(self.video_tokenizer):
+            return params

-
-            dim_in = 2,
-            dim = dim * 2,
-            dim_out = dim,
-            depth = decoder_pos_mlp_depth,
-        )
+        return list(set(params) - set(self.video_tokenizer.parameters()))

-
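The `muon_parameters()` hooks threaded through these classes exist so a trainer can route the 2-D weight matrices to a Muon-style optimizer while everything else (norms, embeddings, biases) goes to an Adam-family optimizer. A minimal sketch of that split under the same convention; the optimizer choices below are stand-ins, not the package's trainer:

```python
import torch
from torch import nn

class Block(nn.Module):
    # mirrors the muon_parameters() convention used throughout the diff
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.RMSNorm(dim)
        self.proj_in = nn.Linear(dim, dim * 2)
        self.proj_out = nn.Linear(dim * 2, dim)

    def muon_parameters(self):
        return [self.proj_in.weight, self.proj_out.weight]

model = Block(64)

muon_params = set(model.muon_parameters())
other_params = [p for p in model.parameters() if p not in muon_params]

# two optimizers, one per parameter group (SGD standing in for a real Muon implementation)
opt_matrices = torch.optim.SGD(list(muon_params), lr = 1e-3)
opt_rest = torch.optim.AdamW(other_params, lr = 3e-4)
```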
+    # helpers for shortcut flow matching

-
-
-
-
-
+    def get_times_from_signal_level(
+        self,
+        signal_levels,
+        align_dims_left_to = None
+    ):
+        times = signal_levels.float() / self.max_steps

-
-
+        if not exists(align_dims_left_to):
+            return times

-
+        return align_dims_left(times, align_dims_left_to)

-
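In the shortcut formulation, a discrete signal level k out of `max_steps` maps to a continuous time t = k / K_max, which is then broadcast against the latent by left-aligning dims. A small illustration with assumed shapes; the linear noise/data interpolation at the end is one common flow-matching convention, shown for intuition only (the package's exact noising is not in this hunk):

```python
import torch
from math import log2

max_steps = 64                                   # K_max, a power of 2
num_step_sizes = int(log2(max_steps))            # step sizes come in powers of two: 1/2, 1/4, ..., 1/K_max

signal_levels = torch.randint(0, max_steps, (4,))
times = signal_levels.float() / max_steps        # in [0, 1): fraction of the way from noise to data

# align dims left so times broadcasts over a (b t n d) latent
latents = torch.randn(4, 8, 4, 16)
times = times.view(-1, *((1,) * (latents.ndim - 1)))   # (b 1 1 1)

noise = torch.randn_like(latents)
noised = latents.lerp(noise, 1. - times)         # t = 0 gives pure noise, t -> 1 approaches the data
```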
+    # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037

-
-
+    @torch.no_grad()
+    def evolve_(
+        self,
+        fitness,
+        select_frac = 0.5,
+        tournament_frac = 0.5
+    ):
+        assert fitness.numel() == self.num_latent_genes

-
-        self.lpips = LPIPSLoss(lpips_loss_network)
+        pop = self.latent_genes

-
-
-
+        pop_size = self.num_latent_genes
+        num_selected = ceil(pop_size * select_frac)
+        num_children = pop_size - num_selected
+
+        dim_gene = pop.shape[-1]
+
+        # natural selection is just a sort and slice
+
+        selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
+        selected = pop[selected_indices]
+
+        # use tournament - one tournament per child
+
+        tournament_size = max(2, ceil(num_selected * tournament_frac))
+
+        tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
+
+        parent_ids = selected_fitness[tournaments].topk(2, dim = -1).indices # get top 2 winners as parents
+
+        parents = selected[parent_ids]
+
+        # crossover by random interpolation from parent1 to parent2
+
+        random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
+
+        parent1, parent2 = parents.unbind(dim = 1)
+        children = parent1.lerp(parent2, random_uniform_mix)
+
+        # store next population
+
+        next_pop = cat((selected, children))
+
+        self.latent_genes.copy_(next_pop)
+
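`evolve_` above is a compact genetic algorithm over the latent gene pool: truncation selection, one tournament among the survivors per child, then crossover by random interpolation. A standalone sketch of the same recipe; note the extra `gather` here, which maps tournament-local winner positions back to indices into the survivor pool before indexing:

```python
import torch
from math import ceil

pop_size, dim_gene = 8, 16
pop = torch.randn(pop_size, dim_gene)
fitness = torch.randn(pop_size)

num_selected = ceil(pop_size * 0.5)
num_children = pop_size - num_selected

# natural selection: keep the fittest half
selected_fitness, selected_idx = fitness.topk(num_selected, dim = -1)
selected = pop[selected_idx]

# one tournament per child, each a random subset of the survivors
tournament_size = max(2, ceil(num_selected * 0.5))
tournaments = torch.randn(num_children, num_selected).argsort(dim = -1)[:, :tournament_size]

# top-2 within each tournament become the parents
local_winners = selected_fitness[tournaments].topk(2, dim = -1).indices
parent_ids = tournaments.gather(-1, local_winners)     # tournament-local -> survivor indices
parents = selected[parent_ids]                         # (num_children, 2, dim_gene)

# crossover: random per-gene interpolation between the two parents
mix = torch.rand(num_children, dim_gene)
parent1, parent2 = parents.unbind(dim = 1)
children = parent1.lerp(parent2, mix)

next_pop = torch.cat((selected, children))             # (pop_size, dim_gene)
```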
+    # interacting with env for experience

     @torch.no_grad()
-    def
+    def interact_with_env(
         self,
-
+        env,
+        seed = None,
+        agent_index = 0,
+        step_size = 4,
+        max_timesteps = 16,
+        env_is_vectorized = False,
+        use_time_kv_cache = True,
+        store_agent_embed = True,
+        store_old_action_unembeds = True,
     ):
-        self.
-        return self.forward(video, return_latents = True)
+        assert exists(self.video_tokenizer)

-
-        self,
-        time,
-        num_patch_height,
-        num_patch_width
-    ):
-        device = self.device
+        init_frame = env.reset()

-
-            arange(time, device = device),
-            arange(num_patch_height, device = device),
-            arange(num_patch_width, device = device)
-        ), dim = -1)
+        # frame to video

-
+        if env_is_vectorized:
+            video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
+        else:
+            video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

-
+        batch, device = video.shape[0], video.device

-
+        # accumulate

-
+        rewards = None
+        discrete_actions = None
+        continuous_actions = None
+        discrete_log_probs = None
+        continuous_log_probs = None
+        values = None
+        latents = None

-
+        acc_agent_embed = None
+        acc_policy_embed = None

-
-        self,
-        latents, # (b t n d)
-        height = None,
-        width = None,
-        rotary_pos_emb = None
-    ): # (b c t h w)
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

-
-
+        is_terminated = full((batch,), False, device = device)
+        is_truncated = full((batch,), False, device = device)

-
+        episode_lens = full((batch,), 0, device = device)

-
+        # maybe time kv cache

-
+        time_kv_cache = None

-
-        num_patch_width = width // self.patch_size
+        step_index = 0

-
-
+        while not is_terminated.all():
+            step_index += 1

-
+            latents = self.video_tokenizer(video, return_latents = True)

-
+            _, (agent_embed, next_time_kv_cache) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                time_kv_cache = time_kv_cache,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )

-
+            # time kv cache

-
-
+            if use_time_kv_cache:
+                time_kv_cache = next_time_kv_cache

-
+            # get one agent

-
-        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+            one_agent_embed = agent_embed[..., -1:, agent_index, :]

-
+            # values

-
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)

-
+            values = safe_cat((values, value), dim = 1)

-
+            # policy embed

-
+            policy_embed = self.policy_head(one_agent_embed)

-
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)

-
+            # sample actions

-
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

-
+            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
+            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)

-
+            # get the log prob and values for policy optimization

-
+            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                pred_head_index = 0,
+                discrete_targets = sampled_discrete_actions,
+                continuous_targets = sampled_continuous_actions,
+            )

-
+            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
+            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-
+            # pass the sampled action to the environment and get back next state and reward

-
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))

-
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminated = full((batch,), False)
+                truncated = full((batch,), False)

-
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminated = env_step_out
+                truncated = full((batch,), False)

-
+            elif len(env_step_out) == 4:
+                next_frame, reward, terminated, truncated = env_step_out

-
-
-        video, # (b c t h w)
-        return_latents = False,
-        mask_patches = None,
-        return_all_losses = False
-    ):
-        batch, _, time, height, width = video.shape
-        patch_size, device = self.patch_size, video.device
+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out

-
+            # update episode lens

-
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)

-
+            # update `is_terminated`

-
+            # (1) - environment says it is terminated
+            # (2) - previous step is truncated (this step is for bootstrap value)

-
+            is_terminated |= (terminated | is_truncated)

-
+            # update `is_truncated`

-
+            if step_index <= max_timesteps:
+                is_truncated |= truncated

-
+            if step_index == max_timesteps:
+                # if the step index is at the max time step allowed, set the truncated flag, if not already terminated

-
+                is_truncated |= ~is_terminated

-
-            min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
+            # batch and time dimension

-
+            if env_is_vectorized:
+                next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
+                reward = rearrange(reward, 'b -> b 1')
+            else:
+                next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+                reward = rearrange(reward, ' -> 1 1')
+
+            # concat
+
+            video = cat((video, next_frame), dim = 2)
+            rewards = safe_cat((rewards, reward), dim = 1)
+
+            acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
+
+        # package up one experience for learning
+
+        batch, device = latents.shape[0], latents.device
+
+        one_experience = Experience(
+            latents = latents,
+            video = video[:, :, :-1],
+            rewards = rewards,
+            actions = (discrete_actions, continuous_actions),
+            log_probs = (discrete_log_probs, continuous_log_probs),
+            values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
+            step_size = step_size,
+            agent_index = agent_index,
+            is_truncated = is_truncated,
+            lens = episode_lens,
+            is_from_world_model = False
+        )

-
-        mask_patch = torch.bernoulli(mask_prob) == 1.
+        return one_experience

-
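Taken together, `interact_with_env` and the `learn_from_experience` method that follows form an on-policy loop: roll out the current policy against a real environment, then optimize from the packaged `Experience`. A hypothetical end-to-end sketch; the import path, `DummyEnv`, and the small dimensions are illustrative assumptions, not prescribed by the package:

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel  # assumed import path

class DummyEnv:
    # minimal env contract used above: reset() -> (c vh vw), step(actions) -> 2- to 5-tuple
    def reset(self):
        return torch.rand(3, 64, 64)

    def step(self, actions):
        return torch.rand(3, 64, 64), torch.tensor(0.5)   # (next_frame, reward)

tokenizer = VideoTokenizer(dim = 64, dim_latent = 16, patch_size = 8, num_latent_tokens = 4)

world_model = DynamicsWorldModel(
    dim = 64,
    dim_latent = 16,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4
)

experience = world_model.interact_with_env(DummyEnv(), max_timesteps = 4)
world_model.learn_from_experience(experience)   # PMPO by default, PPO with use_pmpo = False
```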
# ppo
|
|
981
2566
|
|
|
982
|
-
|
|
2567
|
+
def learn_from_experience(
|
|
2568
|
+
self,
|
|
2569
|
+
experience: Experience,
|
|
2570
|
+
policy_optim: Optimizer | None = None,
|
|
2571
|
+
value_optim: Optimizer | None = None,
|
|
2572
|
+
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
|
2573
|
+
use_pmpo = True,
|
|
2574
|
+
normalize_advantages = None,
|
|
2575
|
+
eps = 1e-6
|
|
2576
|
+
):
|
|
2577
|
+
assert isinstance(experience, Experience)
|
|
983
2578
|
|
|
984
|
-
|
|
2579
|
+
experience = experience.to(self.device)
|
|
985
2580
|
|
|
986
|
-
|
|
2581
|
+
latents = experience.latents
|
|
2582
|
+
actions = experience.actions
|
|
2583
|
+
old_log_probs = experience.log_probs
|
|
2584
|
+
old_values = experience.values
|
|
2585
|
+
rewards = experience.rewards
|
|
2586
|
+
agent_embeds = experience.agent_embed
|
|
2587
|
+
old_action_unembeds = experience.old_action_unembeds
|
|
987
2588
|
|
|
988
|
-
|
|
2589
|
+
step_size = experience.step_size
|
|
2590
|
+
agent_index = experience.agent_index
|
|
989
2591
|
|
|
990
|
-
|
|
2592
|
+
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
|
|
991
2593
|
|
|
992
|
-
|
|
2594
|
+
batch, time = latents.shape[0], latents.shape[1]
|
|
993
2595
|
|
|
994
|
-
#
|
|
2596
|
+
# calculate returns
|
|
995
2597
|
|
|
996
|
-
|
|
2598
|
+
# mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
|
|
997
2599
|
|
|
998
|
-
|
|
2600
|
+
if not exists(experience.is_truncated):
|
|
2601
|
+
experience.is_truncated = full((batch,), True, device = latents.device)
|
|
999
2602
|
|
|
1000
|
-
|
|
2603
|
+
if exists(experience.lens):
|
|
2604
|
+
mask_for_gae = lens_to_mask(experience.lens, time)
|
|
1001
2605
|
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
causal_block_size = space_seq_len,
|
|
1005
|
-
softclamp_value = self.attn_softclamp_value,
|
|
1006
|
-
block_size_per_special = space_seq_len,
|
|
1007
|
-
num_special_tokens = 1
|
|
1008
|
-
)
|
|
2606
|
+
rewards = rewards.masked_fill(~mask_for_gae, 0.)
|
|
2607
|
+
old_values = old_values.masked_fill(~mask_for_gae, 0.)
|
|
1009
2608
|
|
|
1010
|
-
|
|
2609
|
+
# calculate returns
|
|
1011
2610
|
|
|
1012
|
-
|
|
2611
|
+
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
|
|
1013
2612
|
|
|
1014
|
-
#
|
|
1015
|
-
# similar to agent token in dynamics model
|
|
2613
|
+
# handle variable lengths
|
|
1016
2614
|
|
|
1017
|
-
|
|
2615
|
+
max_time = latents.shape[1]
|
|
2616
|
+
is_var_len = exists(experience.lens)
|
|
1018
2617
|
|
|
1019
|
-
|
|
2618
|
+
mask = None
|
|
1020
2619
|
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
2620
|
+
if is_var_len:
|
|
2621
|
+
learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
|
|
2622
|
+
mask = lens_to_mask(learnable_lens, max_time)
|
|
1024
2623
|
|
|
1025
|
-
|
|
2624
|
+
# determine whether to finetune entire transformer or just learn the heads
|
|
1026
2625
|
|
|
1027
|
-
|
|
2626
|
+
world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
|
|
1028
2627
|
|
|
1029
|
-
|
|
2628
|
+
# maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
|
|
1030
2629
|
|
|
1031
|
-
|
|
2630
|
+
if self.keep_reward_ema_stats:
|
|
2631
|
+
ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
|
|
1032
2632
|
|
|
1033
|
-
|
|
2633
|
+
decay = 1. - self.reward_ema_decay
|
|
1034
2634
|
|
|
1035
|
-
|
|
1036
|
-
return latents
|
|
2635
|
+
# quantile filter
|
|
1037
2636
|
|
|
1038
|
-
|
|
2637
|
+
lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
|
|
2638
|
+
returns_for_stats = returns.clamp(lo, hi)
|
|
1039
2639
|
|
|
1040
|
-
|
|
2640
|
+
# mean, var - todo - handle distributed
|
|
1041
2641
|
|
|
1042
|
-
|
|
2642
|
+
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
|
|
1043
2643
|
|
|
1044
|
-
|
|
2644
|
+
# ema
|
|
1045
2645
|
|
|
1046
|
-
|
|
1047
|
-
|
|
2646
|
+
ema_returns_mean.lerp_(returns_mean, decay)
|
|
2647
|
+
ema_returns_var.lerp_(returns_var, decay)
|
|
1048
2648
|
|
|
1049
|
-
|
|
2649
|
+
# normalize
|
|
1050
2650
|
|
|
1051
|
-
|
|
1052
|
-
recon_loss +
|
|
1053
|
-
lpips_loss * self.lpips_loss_weight
|
|
1054
|
-
)
|
|
2651
|
+
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
|
|
1055
2652
|
|
|
1056
|
-
|
|
1057
|
-
|
|
2653
|
+
normed_returns = (returns - ema_returns_mean) / ema_returns_std
|
|
2654
|
+
normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
|
|
1058
2655
|
|
|
1059
|
-
|
|
2656
|
+
advantage = normed_returns - normed_old_values
|
|
2657
|
+
else:
|
|
2658
|
+
advantage = returns - old_values
|
|
1060
2659
|
|
|
1061
|
-
|
|
2660
|
+
# if using pmpo, do not normalize advantages, but can be overridden
|
|
1062
2661
|
|
|
1063
|
-
|
|
2662
|
+
normalize_advantages = default(normalize_advantages, not use_pmpo)
|
|
1064
2663
|
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
self,
|
|
1068
|
-
dim,
|
|
1069
|
-
dim_latent,
|
|
1070
|
-
video_tokenizer: VideoTokenizer | None = None,
|
|
1071
|
-
max_steps = 64, # K_max in paper
|
|
1072
|
-
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
|
|
1073
|
-
num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
|
|
1074
|
-
num_latent_tokens = None,
|
|
1075
|
-
num_agents = 1,
|
|
1076
|
-
num_tasks = 0,
|
|
1077
|
-
reward_encoder_kwargs: dict = dict(),
|
|
1078
|
-
depth = 4,
|
|
1079
|
-
pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
|
|
1080
|
-
time_block_every = 4, # every 4th block is time
|
|
1081
|
-
attn_kwargs: dict = dict(
|
|
1082
|
-
heads = 8,
|
|
1083
|
-
),
|
|
1084
|
-
attn_dim_head = 64,
|
|
1085
|
-
attn_softclamp_value = 50.,
|
|
1086
|
-
ff_kwargs: dict = dict(),
|
|
1087
|
-
loss_weight_fn: Callable = ramp_weight,
|
|
1088
|
-
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
|
|
1089
|
-
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
|
1090
|
-
add_reward_embed_to_agent_token = False,
|
|
1091
|
-
add_reward_embed_dropout = 0.1,
|
|
1092
|
-
reward_loss_weight = 0.1,
|
|
1093
|
-
value_head_mlp_depth = 3,
|
|
1094
|
-
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
|
1095
|
-
):
|
|
1096
|
-
super().__init__()
|
|
+        if normalize_advantages:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

-        #
+        # https://arxiv.org/abs/2410.04166v1

-
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask

-
-
-        assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
+        # replay for the action logits and values
+        # but only do so if fine tuning the entire world model for RL

-
+        discrete_actions, continuous_actions = actions

-
+        if (
+            not only_learn_policy_value_heads or
+            not exists(agent_embeds)
+        ):

-
-
-
+            with world_model_forward_context():
+                _, (agent_embeds, _) = self.forward(
+                    latents = latents,
+                    signal_levels = self.max_steps - 1,
+                    step_sizes = step_size,
+                    rewards = rewards,
+                    discrete_actions = discrete_actions,
+                    continuous_actions = continuous_actions,
+                    latent_is_noised = True,
+                    return_pred_only = True,
+                    return_intermediates = True
+                )

-
-        assert divisible_by(num_spatial_tokens, num_latent_tokens)
+            agent_embeds = agent_embeds[..., agent_index, :]

-
+        # maybe detach agent embed

-
-
-            Rearrange('... (s d) -> ... s d', s = expand_factor)
-        )
+        if only_learn_policy_value_heads:
+            agent_embeds = agent_embeds.detach()

-
-            Reduce('b t n s d -> b t n d', 'mean'),
-            RMSNorm(dim),
-            LinearNoBias(dim, dim_latent)
-        )
+        # ppo

-
-        assert divisible_by(num_latent_tokens, num_spatial_tokens)
-        latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
+        policy_embed = self.policy_head(agent_embeds)

-
-            Rearrange('b t n d -> b t (n d)'),
-            Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-            Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
-        )
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)

-
-            RMSNorm(dim),
-            LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-            Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
-        )
+        # concat discrete and continuous actions into one for optimizing

-
+        old_log_probs = safe_cat(old_log_probs, dim = -1)
+        log_probs = safe_cat(log_probs, dim = -1)
+        entropies = safe_cat(entropies, dim = -1)

-
-        self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions

-
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1

-
-
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask

-
-        self.max_steps = max_steps
-        self.num_step_sizes_log2 = int(log2(max_steps))
+            α = self.pmpo_pos_to_neg_weight

-
-
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)

-
+            policy_loss = -(α * pos + (1. - α) * neg)

-
+            # take care of kl

-
-
+            if self.pmpo_kl_div_loss_weight > 0.:
+
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds

-
+                # mentioned that the "reverse direction for the prior KL" was used
+                # make optional, as observed instability in toy task

-
+                if self.pmpo_reverse_kl:
+                    kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs

-
-        self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)

-
-        self.task_embed = nn.Embedding(num_tasks, dim)
+                # accumulate discrete and continuous kl div

-
+                kl_div_loss = 0.

-
-
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()

-
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()

-
-        self.add_reward_embed_dropout = add_reward_embed_dropout
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight

-
-
-            dim_embed = dim,
-            learned_embedding = add_reward_embed_to_agent_token
-        )
+        else:
+            # ppo clipped surrogate loss

-
-
-            LinearNoBias(dim, self.reward_encoder.num_bins)
-        )
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)

-
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')

-
+            policy_loss = masked_mean(policy_loss, mask)

-
-            dim_in = dim,
-            dim = dim * 4,
-            dim_out = self.reward_encoder.num_bins,
-            depth = value_head_mlp_depth,
-        )
+        # handle entropy loss for naive exploration bonus

-
+        entropy_loss = -reduce(entropies, 'b t na -> b t', 'sum')

-
+        entropy_loss = masked_mean(entropy_loss, mask)

-        #
+        # total policy loss

-
+        total_policy_loss = (
+            policy_loss +
+            entropy_loss * self.policy_entropy_weight
+        )

-        #
+        # maybe take policy optimizer step

-
-
+        if exists(policy_optim):
+            total_policy_loss.backward()

-
-
+            policy_optim.step()
+            policy_optim.zero_grad()

-
-        is_time.append(is_time_block)
+        # value loss

-
-
+        value_bins = self.value_head(agent_embeds)
+        values = self.reward_encoder.bins_to_scalar_value(value_bins)

-
-
-            rearrange_from_attend,
-            Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
-            SwiGLUFeedforward(dim = dim, **ff_kwargs)
-        ]))
+        clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
+        clipped_value_bins = self.reward_encoder(clipped_values)

-
-        self.is_time = is_time
+        return_bins = self.reward_encoder(returns)

-
+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))

-
+        value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
+        value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-
-    def device(self):
-        return self.zero.device
+        value_loss = torch.maximum(value_loss_1, value_loss_2)

-
-        self,
-        signal_levels,
-        align_dims_left_to = None
-    ):
-        times = signal_levels.float() / self.max_steps
+        # maybe variable length

-        if
-
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

-
+        # maybe take value optimizer step

-
-
+        if exists(value_optim):
+            value_loss.backward()

-
-
+            value_optim.step()
+            value_optim.zero_grad()

-        return
+        return total_policy_loss, value_loss

     @torch.no_grad()
     def generate(
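For orientation, a side-by-side sketch of the two policy objectives implemented in the hunk above: PMPO (eq (10) of https://arxiv.org/abs/2410.04166) keeps only the sign of the advantage and weights the positive and negative sets by α, while PPO uses the clipped importance-ratio surrogate. This is a minimal standalone sketch with toy shapes, and `masked_mean` is an assumed helper matching its call sites above, not the package's own definition.

import torch

def masked_mean(t, mask):
    # average only over positions selected by the boolean mask
    return t[mask].mean()

log_probs = torch.randn(2, 8)          # (b, t) log prob of taken actions
old_log_probs = log_probs.detach() + 0.01 * torch.randn(2, 8)
advantage = torch.randn(2, 8)

# pmpo: weight positive and negative advantage sets equally, ignoring magnitude
alpha = 0.5
pos_mask = advantage >= 0.
pmpo_loss = -(alpha * masked_mean(log_probs, pos_mask) - (1. - alpha) * masked_mean(log_probs, ~pos_mask))

# ppo: clipped surrogate on the importance ratio
eps = 0.2
ratio = (log_probs - old_log_probs).exp()
ppo_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps, 1. + eps) * advantage).mean()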
@@ -1276,20 +2831,49 @@ class DynamicsWorldModel(Module):
         num_steps = 4,
         batch_size = 1,
         agent_index = 0,
+        tasks: int | Tensor | None = None,
+        latent_gene_ids = None,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-
+        time_kv_cache: Tensor | None = None,
+        use_time_kv_cache = True,
+        return_rewards_per_frame = False,
+        return_agent_actions = False,
+        return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
+        return_time_kv_cache = False,
+        store_agent_embed = True,
+        store_old_action_unembeds = True

     ): # (b t n d) | (b c t h w)

+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
+        has_proprio = self.has_proprio
         was_training = self.training
         self.eval()

+        # validation
+
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

+        if isinstance(tasks, int):
+            tasks = full((batch_size,), tasks, device = self.device)
+
+        assert not exists(tasks) or tasks.shape[0] == batch_size
+
+        # get state latent shape
+
         latent_shape = self.latent_shape

         # derive step size
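Since `num_steps` is validated to be a power of two no greater than `max_steps`, the step size is presumably the number of signal levels each denoising jump spans. A minimal sketch under that assumption (the helper name is illustrative, not the package's):

from math import log2

def derive_step_size(num_steps: int, max_steps: int) -> int:
    assert log2(num_steps).is_integer() and 0 < num_steps <= max_steps
    # covering max_steps signal levels in num_steps equal jumps
    return max_steps // num_steps

assert derive_step_size(4, 64) == 16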
@@ -1299,12 +2883,41 @@ class DynamicsWorldModel(Module):
         # denoising
         # teacher forcing to start with

-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
+
+        past_latents_context_noise = latents.clone()
+
+        # maybe internal state
+
+        if has_proprio:
+            proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
+
+            past_proprio_context_noise = proprio.clone()
+
+        # maybe return actions
+
+        return_agent_actions |= return_log_probs_and_values
+
+        decoded_discrete_actions = None
+        decoded_continuous_actions = None
+
+        # policy optimization related

-
+        decoded_discrete_log_probs = None
+        decoded_continuous_log_probs = None
+        decoded_values = None
+
+        # maybe store agent embed
+
+        acc_agent_embed = None
+
+        # maybe store old actions for kl
+
+        acc_policy_embed = None

         # maybe return rewards

+        decoded_rewards = None
         if return_rewards_per_frame:
             decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
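The sampling loop below appends to many accumulators that start out as `None` via `safe_cat`. The helper is defined elsewhere in the file; a minimal sketch of the semantics its call sites imply (`None` entries are simply skipped), with illustrative code:

import torch

def safe_cat(tensors, dim = -1):
    if torch.is_tensor(tensors):   # single tensor passes through untouched
        return tensors
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)

acc = None
for step_out in (torch.ones(1, 2), torch.zeros(1, 2)):
    acc = safe_cat((acc, step_out), dim = 0)   # (2, 2) after the loop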
@@ -1313,60 +2926,182 @@ class DynamicsWorldModel(Module):
         while latents.shape[1] < time_steps:

             curr_time_steps = latents.shape[1]
-            noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)

-
+            # determine whether to take an extra step if
+            # (1) using time kv cache
+            # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+            take_extra_step = (
+                use_time_kv_cache or
+                return_rewards_per_frame or
+                store_agent_embed or
+                return_agent_actions
+            )
+
+            # prepare noised latent / proprio inputs
+
+            noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
+
+            noised_proprio = None
+
+            if has_proprio:
+                noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
+
+            # denoising steps
+
+            for step in range(num_steps + int(take_extra_step)):
+
+                is_last_step = (step + 1) == num_steps
+
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
+
+                # noising past latent context
+
+                noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
+
+                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')

-
+                # handle proprio

-
+                noised_proprio_with_context = None
+
+                if has_proprio:
+                    noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
+                    noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
+
+                # proper signal levels

                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-            pred, agent_embed = self.forward(
+                pred, (agent_embed, next_time_kv_cache) = self.forward(
                     latents = noised_latent_with_context,
                     signal_levels = signal_levels_with_context,
                     step_sizes = step_size,
                     rewards = decoded_rewards,
+                    tasks = tasks,
+                    latent_gene_ids = latent_gene_ids,
+                    discrete_actions = decoded_discrete_actions,
+                    continuous_actions = decoded_continuous_actions,
+                    proprio = noised_proprio_with_context,
+                    time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
+                    latent_has_view_dim = True,
                     return_pred_only = True,
-
+                    return_intermediates = True,
                 )

-
+                if use_time_kv_cache and is_last_step:
+                    time_kv_cache = next_time_kv_cache
+
+                # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+                if take_extra_step and is_last_step:
+                    break
+
+                # maybe proprio
+
+                if has_proprio:
+                    pred, pred_proprio = pred
+
+                # unpack pred
+
+                _, pred = unpack(pred, pack_context_shape, 'b * v n d')
+
+                if has_proprio:
+                    _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')

                 # derive flow, based on whether in x-space or not

-
-
-
-
-
+                def denoise_step(pred, noised, signal_levels):
+                    if self.pred_orig_latent:
+                        times = self.get_times_from_signal_level(signal_levels)
+                        aligned_times = align_dims_left(times, noised)
+
+                        flow = (pred - noised) / (1. - aligned_times)
+                    else:
+                        flow = pred
+
+                    return flow * (step_size / self.max_steps)

                 # denoise

-            noised_latent +=
+                noised_latent += denoise_step(pred, noised_latent, signal_levels)
+
+                if has_proprio:
+                    noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)

             denoised_latent = noised_latent # it is now denoised

+            if has_proprio:
+                denoised_proprio = noised_proprio
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step

             if return_rewards_per_frame:
                 one_agent_embed = agent_embed[:, -1:, agent_index]

-                reward_logits = self.to_reward_pred(one_agent_embed)
+                reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
                 pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)

                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+            # maybe store agent embed
+
+            if store_agent_embed:
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+                acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
+
+            # decode the agent actions if needed
+
+            if return_agent_actions:
+                assert self.action_embedder.has_actions
+
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
+                policy_embed = self.policy_head(one_agent_embed)
+
+                # maybe store old actions
+
+                if store_old_action_unembeds:
+                    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
+                # sample actions
+
+                sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
+
+                decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
+                decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
+
+                if return_log_probs_and_values:
+                    discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                        policy_embed,
+                        pred_head_index = 0,
+                        discrete_targets = sampled_discrete_actions,
+                        continuous_targets = sampled_continuous_actions,
+                    )
+
+                    decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
+                    decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)
+
+                    value_bins = self.value_head(one_agent_embed)
+                    values = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+                    decoded_values = safe_cat((decoded_values, values), dim = 1)
+
             # concat the denoised latent

             latents = cat((latents, denoised_latent), dim = 1)

             # add new fixed context noise for the temporal consistency

-
+            past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
+
+            # handle proprio
+
+            if has_proprio:
+                proprio = cat((proprio, denoised_proprio), dim = 1)
+
+                past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)

             # restore state
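The inner `denoise_step` above converts an x-space prediction (the model estimating the clean data x0) back into a velocity before taking an Euler step. A standalone sketch of that rectified-flow identity, with illustrative tensor shapes:

import torch

max_steps = 64
step_size = 16
t = torch.full((2, 1, 1, 1, 1), 32 / max_steps)   # current signal level as time in [0, 1)

noised = torch.randn(2, 1, 1, 4, 8)               # x_t
pred_x0 = torch.randn(2, 1, 1, 4, 8)              # model's estimate of the clean data

# for linear interpolation x_t = (1 - t) * noise + t * x0,
# the velocity is v = x0 - noise = (x0 - x_t) / (1 - t)
flow = (pred_x0 - noised) / (1. - t)

# one Euler step covering step_size of the max_steps signal levels
noised = noised + flow * (step_size / max_steps)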
@@ -1377,52 +3112,128 @@ class DynamicsWorldModel(Module):
         has_tokenizer = exists(self.video_tokenizer)
         return_decoded_video = default(return_decoded_video, has_tokenizer)

-
-
-
+        video = None
+
+        if return_decoded_video:
+
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
+            video = self.video_tokenizer.decode(
+                latents_for_video,
+                height = image_height,
+                width = image_width
+            )
+
+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension

-
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')

-
-
-
-
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
+        # only return video or latent if not requesting anything else, for first stage training
+
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
+            out = video if return_decoded_video else latents
+
+            if not return_time_kv_cache:
+                return out
+
+            return out, time_kv_cache
+
+        # returning agent actions, rewards, and log probs + values for policy optimization
+
+        batch, device = latents.shape[0], latents.device
+        experience_lens = full((batch,), time_steps, device = device)
+
+        gen = Experience(
+            latents = latents,
+            video = video,
+            proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
+            step_size = step_size,
+            agent_index = agent_index,
+            lens = experience_lens,
+            is_from_world_model = True
         )

-        if
-
+        if return_rewards_per_frame:
+            gen.rewards = decoded_rewards
+
+        if return_agent_actions:
+            gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
+
+        if return_log_probs_and_values:
+            gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
+
+            gen.values = decoded_values
+
+        if not return_time_kv_cache:
+            return gen

-        return
+        return gen, time_kv_cache

     def forward(
         self,
         *,
-        video = None,
-        latents = None,
-
-
-
-
-
-
+        video = None, # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
+        signal_levels = None, # () | (b) | (b t)
+        step_sizes = None, # () | (b)
+        step_sizes_log2 = None, # () | (b)
+        latent_gene_ids = None, # (b)
+        tasks = None, # (b)
+        rewards = None, # (b t)
+        discrete_actions = None, # (b t na) | (b t-1 na)
+        continuous_actions = None, # (b t na) | (b t-1 na)
+        discrete_action_types = None, # (na)
+        continuous_action_types = None, # (na)
+        proprio = None, # (b t dp)
+        time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-
+        return_intermediates = False,
+        add_autoregressive_action_loss = True,
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents

         assert exists(video) ^ exists(latents)

+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'

             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')

-        if latents.ndim ==
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

-        assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

         # variables
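A small illustration of the "standardize view dimension" step above: single-view inputs get a singleton view axis inserted so every downstream op can assume a (b t v n d) layout, and einops makes the insertion and removal explicit. Toy shapes only:

import torch
from einops import rearrange

latents = torch.randn(2, 10, 16, 64)                 # (b t n d), no view dim
latents = rearrange(latents, 'b t ... -> b t 1 ...') # (b t 1 n d)

assert latents.shape == (2, 10, 1, 16, 64)

# and stripped again on the way out
latents = rearrange(latents, 'b t 1 ... -> b t ...')
assert latents.shape == (2, 10, 16, 64)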
@@ -1497,29 +3308,30 @@ class DynamicsWorldModel(Module):

         # times is from 0 to 1

-        times = self.get_times_from_signal_level(signal_levels
+        times = self.get_times_from_signal_level(signal_levels)

         if not latent_is_noised:
             # get the noise

             noise = randn_like(latents)
+            aligned_times = align_dims_left(times, latents)

             # noise from 0 as noise to 1 as data

-            noised_latents = noise.lerp(latents,
+            noised_latents = noise.lerp(latents, aligned_times)

         else:
             noised_latents = latents

         # reinforcement learning related

-        agent_tokens = repeat(self.
+        agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)

         if exists(tasks):
             assert self.num_tasks > 0

             task_embeds = self.task_embed(tasks)
-            agent_tokens = agent_tokens
+            agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)

         # maybe evolution
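The noising above is plain linear interpolation between noise (t = 0) and data (t = 1), which `torch.lerp` expresses directly. A sketch with toy shapes:

import torch

data = torch.randn(2, 10, 16, 64)
noise = torch.randn_like(data)
t = torch.rand(2).reshape(2, 1, 1, 1)   # per-sample time, broadcast left-aligned

# noise.lerp(data, t) == noise + t * (data - noise) == (1 - t) * noise + t * data
noised = noise.lerp(data, t)

assert torch.allclose(noised, (1 - t) * noise + t * data, atol = 1e-6)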
@@ -1527,13 +3339,15 @@ class DynamicsWorldModel(Module):
             assert exists(self.latent_genes)
             latent_genes = self.latent_genes[latent_gene_ids]

-            agent_tokens =
+            agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)

         # handle agent tokens w/ actions and task embeds

         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

-        # maybe
+        # maybe reward tokens
+
+        reward_tokens = agent_tokens[:, :, 0:0]

         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
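`self.reward_encoder(rewards)` produces a two-hot encoding: each scalar reward is split across its two neighboring bins in proportion to distance. A minimal sketch over uniformly spaced bins (the package's encoder presumably also handles symlog scaling and normalization, omitted here):

import torch

def two_hot(values, num_bins = 11, low = -1., high = 1.):
    values = values.clamp(low, high)
    pos = (values - low) / (high - low) * (num_bins - 1)   # fractional bin index
    lo = pos.floor().long().clamp(max = num_bins - 2)
    frac = (pos - lo.float()).unsqueeze(-1)

    encoding = torch.zeros(*values.shape, num_bins)
    encoding.scatter_(-1, lo.unsqueeze(-1), 1. - frac)     # weight on the lower bin
    encoding.scatter_(-1, (lo + 1).unsqueeze(-1), frac)    # remainder on the upper bin
    return encoding

enc = two_hot(torch.tensor([[0.05]]))   # mass split between the two bins around 0.05
assert torch.allclose(enc.sum(dim = -1), torch.ones(1, 1))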
@@ -1542,30 +3356,102 @@ class DynamicsWorldModel(Module):
             self.add_reward_embed_to_agent_token and
             (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
         ):
-
+            assert self.num_agents == 1
+
+            reward_tokens = self.reward_encoder.embed(two_hot_encoding)
+
+            pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
+
+            reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
+
+            reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
+
+        # maybe proprioception
+
+        assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'

-
+        noised_proprio = None

-
+        if self.has_proprio:

-
+            if not latent_is_noised:
+                # get the noise
+
+                proprio_noise = randn_like(proprio)
+                aligned_times = align_dims_left(times, proprio)
+
+                # noise from 0 as noise to 1 as data
+
+                noised_proprio = proprio_noise.lerp(proprio, aligned_times)
+
+            else:
+                noised_proprio = proprio
+
+        # maybe create the action tokens
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+            assert self.num_agents == 1, 'only one agent allowed for now'
+
+            action_tokens = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+
+            # handle first timestep not having an associated past action
+
+            if action_tokens.shape[1] == (time - 1):
+                action_tokens = pad_at_dim(action_tokens, (1, 0), value = 0., dim = 1)
+
+            action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
+
+        elif self.action_embedder.has_actions:
+            action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
+
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens

         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)

+            # maybe add view embedding
+
+            if self.video_has_multi_view:
+                space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
+
+            # merge spatial tokens
+
             space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')

             num_spatial_tokens = space_tokens.shape[-2]

+            # action tokens
+
+            num_action_tokens = 1 if not is_empty(action_tokens) else 0
+
+            # reward tokens
+
+            num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]

             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)

+            # maybe proprio
+
+            if exists(noised_proprio):
+                proprio_token = self.to_proprio_token(noised_proprio)
+            else:
+                proprio_token = registers[:, :, 0:0]
+
             # determine signal + step size embed for their diffusion forcing + shortcut

             signal_embed = self.signal_levels_embed(signal_levels)
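The token stream in the next hunk is assembled with einops `pack`, which concatenates groups of tokens along one axis and remembers their sizes so `unpack` can split them back out after attention. A toy version of that round trip:

import torch
from einops import pack, unpack

b, t, d = 2, 4, 8
flow_token = torch.randn(b, t, 1, d)
space_tokens = torch.randn(b, t, 6, d)
agent_tokens = torch.randn(b, t, 1, d)

tokens, packed_shape = pack([flow_token, space_tokens, agent_tokens], 'b t * d')
assert tokens.shape == (b, t, 8, d)

# ... attention over `tokens` would go here ...

flow_token, space_tokens, agent_tokens = unpack(tokens, packed_shape, 'b t * d')
assert space_tokens.shape == (b, t, 6, d)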
@@ -1578,77 +3464,86 @@ class DynamicsWorldModel(Module):

             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

-            #
+            # attention

-
+            tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

-
+            # unpack

-
+            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

-
-                + 1 # signal + step
-                + self.num_agents # action / agent tokens
-                + self.num_register_tokens
-                + num_spatial_tokens
-            )
+            # pooling

-
+            space_tokens = inverse_pack_space_per_latent(space_tokens)

-
+            pred = self.to_latent_pred(space_tokens)

-            #
+            # maybe proprio

-
+            if self.has_proprio:
+                pred_proprio = self.to_proprio_pred(proprio_token)

-
+                pred = (pred, pred_proprio)

-
+            # returning

-
+            if not return_agent_tokens:
+                return pred

-
+            if not return_time_kv_cache:
+                return pred, agent_tokens

-
+            return pred, (agent_tokens, next_time_kv_cache)

-
+        # curry into get_prediction what does not change during first call as well as the shortcut ones

-
+        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)

-
+        # forward the network

-
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)

-
+        if return_pred_only:
+            if not return_intermediates:
+                return pred

-
+            return pred, (encoded_agent_tokens, next_time_kv_cache)

-
+        # pack the predictions to calculate flow for different modalities all at once

-
+        if self.has_proprio:
+            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')

-
+            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
+            data, _ = pack((latents, proprio), 'b t *')
+            noise, _ = pack((noise, proprio_noise), 'b t *')
+        else:
+            noised = noised_latents
+            data = latents

-
+        # wrapper function for maybe unpacking and packing modalities for doing flow math in unison

-
+        def maybe_pack_unpack(fn):
+            @wraps(fn)
+            @torch.no_grad()
+            def inner(noised, *args, **kwargs):

-
-            return pred
+                noised_proprio = None

-
+                if self.has_proprio:
+                    noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')

-
+                pred = fn(noised, noised_proprio, *args, **kwargs)

-
+                if self.has_proprio:
+                    pred, _ = pack(pred, 'b t *')

-        if return_pred_only:
-            if not return_agent_tokens:
                 return pred
+            return inner

-
+        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)

         # determine the target for the loss
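The shortcut branch that follows calls the model again under `no_grad` through a wrapper built from `functools.partial` and `wraps`, so the frozen arguments (action, reward, and agent tokens) are bound once. A toy version of that currying pattern, with illustrative names:

import torch
from functools import partial, wraps

def get_prediction(x, extra_tokens, scale = 1.):
    return x * scale + extra_tokens.sum()

# bind the arguments that never change between calls
_get_prediction = partial(get_prediction, extra_tokens = torch.ones(3))

def no_grad_wrapper(fn):
    @wraps(fn)
    @torch.no_grad()
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs)
    return inner

teacher = no_grad_wrapper(_get_prediction)

out = teacher(torch.randn(3, requires_grad = True))
assert not out.requires_grad   # teacher predictions never receive gradients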
@@ -1665,46 +3560,45 @@ class DynamicsWorldModel(Module):
             # x-space as in paper is in else clause

             if is_v_space_pred:
-                pred_target = flow =
+                pred_target = flow = data - noise
             else:
-                pred_target =
+                pred_target = data
         else:
             # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557

             # basically a consistency loss where you ensure quantity of two half steps equals one step
             # dreamer then makes it work for x-space with some math

-            get_prediction_no_grad = torch.no_grad()(get_prediction)
-
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one

-            first_step_pred =
+            first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)

             # first derive b'

             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = self.get_times_from_signal_level(signal_levels,
-
+                first_times = self.get_times_from_signal_level(signal_levels, noised)
+
+                first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)

             # take a half step

-            half_step_size_align_left = align_dims_left(half_step_size,
+            half_step_size_align_left = align_dims_left(half_step_size, noised)

-
+            denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)

             # get second prediction for b''

             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-            second_step_pred =
+            second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)

             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step,
-                second_step_pred_flow = (second_step_pred -
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
+                second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)

             # pred target is sg(b' + b'') / 2
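The shortcut consistency target above requires that a prediction taken with step size d matches the composition of two half steps of size d/2. A schematic sketch in v-space, where the construction is easiest to see (names illustrative, not the package's API):

import torch

def consistency_target(model, x_t, t, half_step):
    with torch.no_grad():
        b1 = model(x_t, t, half_step)                 # velocity b' for the first half step
        x_mid = x_t + b1 * half_step                  # move half a step
        b2 = model(x_mid, t + half_step, half_step)   # velocity b'' for the second half step

    return (b1 + b2) / 2                              # sg((b' + b'') / 2), target for the full step

model = lambda x, t, d: -x                            # stand-in velocity field
target = consistency_target(model, torch.randn(4), torch.tensor(0.25), 0.25)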
@@ -1713,7 +3607,7 @@ class DynamicsWorldModel(Module):
             # need to convert x-space to v-space

             if is_x_space:
-                pred = (pred -
+                pred = (pred - noised) / (1. - first_times)
                 maybe_shortcut_loss_weight = (1. - first_times) ** 2

             # mse loss
@@ -1726,9 +3620,23 @@ class DynamicsWorldModel(Module):

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
+            loss_weight = align_dims_left(loss_weight, flow_losses)
+
             flow_losses = flow_losses * loss_weight

-
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses
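`lens_to_mask` converts per-sample sequence lengths into a boolean mask so padded timesteps are excluded from the mean. A minimal sketch of the helper as its call site implies:

import torch

def lens_to_mask(lens, total_len):
    seq = torch.arange(total_len, device = lens.device)
    return seq < lens[:, None]   # (b, total_len) boolean mask

lens = torch.tensor([3, 5])
mask = lens_to_mask(lens, 5)
# losses[mask].mean() then averages only over the valid (unpadded) positions
assert mask.tolist() == [[True, True, True, False, False], [True] * 5]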
@@ -1739,28 +3647,114 @@ class DynamicsWorldModel(Module):
             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

-            reward_pred = self.to_reward_pred(encoded_agent_tokens)
-
+            reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
+
+            reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
+
+            reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)
+
+            reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
+
+            reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
+
+            reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
+
+            if is_var_len:
+                reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+            else:
+                reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+
+        # maybe autoregressive action loss
+
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero
+
+        if (
+            self.num_agents == 1 and
+            add_autoregressive_action_loss and
+            time > 1 and
+            (exists(discrete_actions) or exists(continuous_actions))
+        ):
+            assert self.action_embedder.has_actions
+
+            # handle actions having time vs time - 1 length
+            # remove the first action if it is equal to time (as it would come from some agent token in the past)
+
+            if exists(discrete_actions) and discrete_actions.shape[1] == time:
+                discrete_actions = discrete_actions[:, 1:]
+
+            if exists(continuous_actions) and continuous_actions.shape[1] == time:
+                continuous_actions = continuous_actions[:, 1:]
+
+            # only for 1 agent
+
+            agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
+            policy_embed = self.policy_head(agent_tokens[:, :-1])
+
+            # constitute multi token prediction targets
+
+            discrete_action_targets = continuous_action_targets = None
+
+            if exists(discrete_actions):
+                discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+                discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
+                discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
+
+            if exists(continuous_actions):
+                continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
+                continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
+                continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
+
+            discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
+                continuous_targets = continuous_action_targets if exists(continuous_actions) else None
+            )
+
+            if exists(discrete_log_probs):
+                discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
+
+                if is_var_len:
+                    discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                    discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+
+            if exists(continuous_log_probs):
+                continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
+
+                if is_var_len:
+                    continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                    continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
+
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)

-
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)

         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight
+            flow_loss * self.latent_flow_loss_weight +
+            (reward_loss * self.reward_loss_weight).sum() +
+            (discrete_action_loss * self.discrete_action_loss_weight).sum() +
+            (continuous_action_loss * self.continuous_action_loss_weight).sum()
         )

         if not return_all_losses:
             return total_loss

-        return total_loss,
-
-# dreamer
-
-class Dreamer(Module):
-    def __init__(
-        self,
-        video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsModel,
-        discount_factor = 0.997
-    ):
-        super().__init__()
+        return total_loss, losses