dreamer4 0.0.31__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 might be problematic.
- dreamer4/__init__.py +8 -2
- dreamer4/dreamer4.py +1552 -197
- dreamer4/mocks.py +97 -0
- dreamer4/trainers.py +525 -3
- {dreamer4-0.0.31.dist-info → dreamer4-0.1.16.dist-info}/METADATA +96 -3
- dreamer4-0.1.16.dist-info/RECORD +8 -0
- dreamer4-0.0.31.dist-info/RECORD +0 -7
- {dreamer4-0.0.31.dist-info → dreamer4-0.1.16.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.31.dist-info → dreamer4-0.1.16.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -3,16 +3,18 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
-from functools import partial
-from dataclasses import dataclass
+from functools import partial, wraps
+from dataclasses import dataclass, asdict
 
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -20,11 +22,13 @@ from torchvision.models import VGG16_Weights
 from torch.optim import Optimizer
 from adam_atan2_pytorch import MuonAdamAtan2
 
-from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
+from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -35,12 +39,15 @@ from assoc_scan import AssocScan
 # d - feature dimension
 # f - frequencies (rotary)
 # l - logit / predicted bins
+# y - layers of transformer
 # p - positions (3 for spacetime in this work)
 # t - time
 # na - action dimension (number of discrete and continuous actions)
 # g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
+# mtp - multi token prediction length
+# v - video viewpoints
 
 import einx
 from einx import add, multiply
@@ -63,22 +70,96 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
+
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
-
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
+MaybeTensor = Tensor | None
 
 @dataclass
 class Experience:
     latents: Tensor
-    video:
+    video: MaybeTensor = None
+    proprio: MaybeTensor = None
+    agent_embed: MaybeTensor = None
     rewards: Tensor | None = None
-    actions: tuple[
-    log_probs: tuple[
-
+    actions: tuple[MaybeTensor, MaybeTensor] | None = None
+    log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
+    old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
+    values: MaybeTensor = None
     step_size: int | None = None
+    lens: MaybeTensor = None
+    is_truncated: MaybeTensor = None
     agent_index: int = 0
    is_from_world_model: bool = True
 
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
+def combine_experiences(
+    exps: list[Experiences]
+) -> Experience:
+
+    assert len(exps) > 0
+
+    # set lens if not there
+
+    for exp in exps:
+        latents = exp.latents
+        batch, time, device = *latents.shape[:2], latents.device
+
+        if not exists(exp.lens):
+            exp.lens = full((batch,), time, device = device)
+
+        if not exists(exp.is_truncated):
+            exp.is_truncated = full((batch,), True, device = device)
+
+    # convert to dictionary
+
+    exps_dict = [asdict(exp) for exp in exps]
+
+    values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
+
+    tree_spec = first(tree_specs)
+
+    all_field_values = list(zip(*values))
+
+    # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
+
+    assert all([
+        all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
+        for field_values in all_field_values
+    ])
+
+    concatted = []
+
+    for field_values in all_field_values:
+
+        if is_tensor(first(field_values)):
+
+            field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
+
+            new_field_value = cat(field_values)
+        else:
+            new_field_value = first(list(set(field_values)))
+
+        concatted.append(new_field_value)
+
+    # return experience
+
+    concat_exp_dict = tree_unflatten(concatted, tree_spec)
+
+    return Experience(**concat_exp_dict)
+
 # helpers
 
 def exists(v):
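The new combine_experiences pads every tensor field along the batched sequence dims before concatenating rollouts of different lengths. A minimal standalone sketch of that pad-then-concat step with made-up shapes (plain torch, not the package's own helpers):

    import torch
    import torch.nn.functional as F

    def right_pad_time(t, target_len):
        # pad dim 1 (time) on the right with zeros; the trailing (0, 0) leaves the last dim untouched
        return F.pad(t, (0, 0, 0, target_len - t.shape[1]))

    a = torch.randn(2, 3, 4)   # (batch, time, dim)
    b = torch.randn(1, 5, 4)

    max_time = max(a.shape[1], b.shape[1])
    merged = torch.cat((right_pad_time(a, max_time), right_pad_time(b, max_time)))
    lens = torch.tensor([3, 3, 5])  # true lengths kept alongside the padded batch, as the `lens` field does

    assert merged.shape == (3, 5, 4)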
@@ -90,6 +171,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def has_at_least_one(*bools):
     return sum([*map(int, bools)]) > 0
 
@@ -105,14 +189,55 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):
     return t.numel() == 0
 
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
+def safe_stack(tensors, dim = 0):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+
+    return stack(tensors, dim = dim)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
 
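lens_to_mask turns per-sequence lengths into a boolean mask over positions; the same result with plain torch broadcasting, as a small worked example:

    import torch

    lens = torch.tensor([2, 4])
    max_len = int(lens.amax())              # 4
    seq = torch.arange(max_len)             # (0, 1, 2, 3)
    mask = seq[None, :] < lens[:, None]     # (2, 4) boolean mask
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])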
@@ -123,6 +248,15 @@ def safe_cat(tensors, dim):
 
     return cat(tensors, dim = dim)
 
+def safe_squeeze_first(t):
+    if not exists(t):
+        return None
+
+    if t.shape[0] != 1:
+        return t
+
+    return rearrange(t, '1 ... -> ...')
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
@@ -159,6 +293,27 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def pad_to_len(t, target_len, *, dim):
+    curr_len = t.shape[dim]
+
+    if curr_len >= target_len:
+        return t
+
+    return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
+
+def pad_tensors_at_dim_to_max_len(
+    tensors: list[Tensor],
+    dims: tuple[int, ...]
+):
+    for dim in dims:
+        if dim >= first(tensors).ndim:
+            continue
+
+        max_time = max([t.shape[dim] for t in tensors])
+        tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
+
+    return tensors
+
 def align_dims_left(t, aligned_to):
     shape = t.shape
     num_right_dims = aligned_to.ndim - t.ndim
@@ -174,8 +329,74 @@ def l2norm(t):
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+def create_multi_token_prediction_targets(
+    t,            # (b t ...)
+    steps_future,
+
+):                # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
+
+    batch, seq_len, device = *t.shape[:2], t.device
+
+    batch_arange = arange(batch, device = device)
+    seq_arange = arange(seq_len, device = device)
+    steps_arange = arange(steps_future, device = device)
+
+    indices = add('t, steps -> t steps', seq_arange, steps_arange)
+    mask = indices < seq_len
+
+    batch_arange = rearrange(batch_arange, 'b -> b 1 1')
+
+    indices[~mask] = 0
+    mask = repeat(mask, 't steps -> b t steps', b = batch)
+
+    out = t[batch_arange, indices]
+
+    return out, mask
+
 # loss related
 
+class LossNormalizer(Module):
+
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor | list[Tensor] | dict[str, Tensor],
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
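LossNormalizer divides each loss by a running RMS kept as an exponential moving average of the squared loss. A minimal sketch of just that update rule outside the module:

    import torch

    beta = 0.95
    exp_avg_sq = torch.ones(())   # running mean of the squared loss

    for step in range(3):
        loss = torch.rand(())
        exp_avg_sq = exp_avg_sq.lerp(loss.square(), 1. - beta)    # EMA update of loss^2
        normed_loss = loss / exp_avg_sq.sqrt().clamp(min = 1e-6)  # divide by the running RMS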
@@ -340,7 +561,9 @@ class ActionEmbedder(Module):
         num_continuous_actions = 0,
         continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
         can_unembed = False,
-        unembed_dim = None
+        unembed_dim = None,
+        num_unembed_preds = 1,
+        squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
     ):
         super().__init__()
 
@@ -378,11 +601,14 @@ class ActionEmbedder(Module):
 
         self.can_unembed = can_unembed
 
+        self.num_unembed_preds = num_unembed_preds
+        self.squeeze_unembed_preds = squeeze_unembed_preds
+
         if not can_unembed:
             return
 
         unembed_dim = default(unembed_dim, dim)
-        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
+        self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
 
         discrete_action_index = arange(total_discrete_actions)
 
@@ -396,13 +622,13 @@ class ActionEmbedder(Module):
 
         self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
 
-        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
+        self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
 
     def embed_parameters(self):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
 
     def unembed_parameters(self):
-        return set([
+        return set([self.discrete_action_unembed, self.continuous_action_unembed])
 
     @property
     def device(self):
@@ -429,12 +655,26 @@ class ActionEmbedder(Module):
         embeds,                         # (... d)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
-        return_split_discrete = False
+        return_split_discrete = False,
+        pred_head_index: int | Tensor | None = None
 
     ): # (... discrete_na), (... continuous_na 2)
 
+        device = embeds.device
+
         assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
 
+        # handle only one prediction head during inference
+
+        if exists(pred_head_index) and isinstance(pred_head_index, int):
+            pred_head_index = tensor(pred_head_index, device = device)
+
+        # if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
+
+        squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
+
+        # get action types
+
         discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
 
         # discrete actions
@@ -450,7 +690,13 @@ class ActionEmbedder(Module):
 
             discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
 
-
+        if exists(pred_head_index):
+            discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
+
+        discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            discrete_action_logits = safe_squeeze_first(discrete_action_logits)
 
         # whether to split the discrete action logits by the number of actions per action type
 
@@ -471,7 +717,13 @@ class ActionEmbedder(Module):
         if exists(continuous_action_types):
             continuous_action_unembed = continuous_action_unembed[continuous_action_types]
 
-
+        if exists(pred_head_index):
+            continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
+
+        continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
+
+        if self.squeeze_unembed_preds or squeeze_one_pred_head:
+            continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
 
         return discrete_action_logits, continuous_action_mean_log_var
 
@@ -480,9 +732,14 @@ class ActionEmbedder(Module):
         embed,
         discrete_temperature = 1.,
         continuous_temperature = 1.,
+        inverse_norm_continuous = None,
+        pred_head_index: int | Tensor | None = None,
+        squeeze = True,
         **kwargs
     ):
-
+        inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
+
+        discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
 
         sampled_discrete = sampled_continuous = None
 
@@ -500,6 +757,12 @@ class ActionEmbedder(Module):
 
             sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
 
+            # maybe inverse norm
+
+            if inverse_norm_continuous:
+                norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
+                sampled_continuous = (sampled_continuous * norm_std) + norm_mean
+
         return sampled_discrete, sampled_continuous
 
     def log_probs(
@@ -509,12 +772,13 @@ class ActionEmbedder(Module):
         continuous_targets = None,      # (... na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        pred_head_index: int | Tensor | None = None,
         parallel_discrete_calc = None,
         return_entropies = False
     ):
         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
 
-        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
+        discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
 
         # discrete
 
@@ -600,10 +864,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None
 
         if exists(continuous_targets):
-
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)
 
             if return_entropies:
@@ -618,6 +879,64 @@ class ActionEmbedder(Module):
 
         return log_probs, entropies
 
+    def kl_div(
+        self,
+        src: tuple[MaybeTensor, MaybeTensor],
+        tgt: tuple[MaybeTensor, MaybeTensor],
+        reduce_across_num_actions = True
+    ) -> tuple[MaybeTensor, MaybeTensor]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        # maybe reduce
+
+        if reduce_across_num_actions:
+            if exists(discrete_kl_div):
+                discrete_kl_div = discrete_kl_div.sum(dim = -1)
+
+            if exists(continuous_kl_div):
+                continuous_kl_div = continuous_kl_div.sum(dim = -1)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
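The discrete branch of the new kl_div relies on F.kl_div taking log-probabilities for the source and probabilities for the target. A small standalone check of that convention, with arbitrary logits:

    import torch
    import torch.nn.functional as F

    src_logits = torch.randn(2, 5)   # e.g. current policy logits
    tgt_logits = torch.randn(2, 5)   # e.g. reference policy logits

    kl = F.kl_div(
        src_logits.log_softmax(dim = -1),
        tgt_logits.softmax(dim = -1),
        reduction = 'none'
    ).sum(dim = -1)                  # KL(tgt || src) per sample, summed over the action bins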
@@ -726,11 +1045,12 @@ class Rotary1D(Module):
 
     def forward(
         self,
-        seq_len
+        seq_len,
+        offset = 0
     ):
         device, dtype = self.inv_freq.device, self.inv_freq.dtype
 
-        t = torch.arange(seq_len, device = device).type(dtype)
+        t = torch.arange(seq_len, device = device).type(dtype) + offset
         freqs = einsum(t, self.inv_freq, 'i, j -> i j')
 
         return cat((freqs, freqs), dim = -1)
@@ -740,7 +1060,18 @@ def apply_rotations(
     rotations, # (h n d) | (n d)
     t          # (b h n d)
 ):
-
+
+    heads, seq_len, dtype = *t.shape[1:3], t.dtype
+
+    rotations_seq_len = rotations.shape[-2]
+
+    # handle kv caching with rotations
+
+    if rotations_seq_len > seq_len:
+        rotations = rotations[-seq_len:]
+
+    # precision
+
     t = t.float()
 
     # handle gqa for rotary
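The new rotary offset exists so a single KV-cached decoding step gets the same rotation it would have received inside the full sequence. A toy sketch of that equivalence, with a hypothetical dim_head rather than the module's own Rotary1D:

    import torch

    dim_head = 8
    inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))

    def rotary_freqs(seq_len, offset = 0):
        # positions start at `offset`, matching where the cached prefix ends
        t = torch.arange(seq_len).float() + offset
        freqs = torch.outer(t, inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

    full_freqs = rotary_freqs(7)           # training: all 7 positions
    step_freqs = rotary_freqs(1, offset = 6)  # inference: only the newest position

    assert torch.allclose(full_freqs[-1:], step_freqs)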
@@ -877,10 +1208,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
@@ -957,6 +1289,10 @@ class Attention(Module):
         query_heads = None,
         heads = 8,
         pre_rmsnorm = True,
+        gate_values = True,
+        rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -979,12 +1315,33 @@ class Attention(Module):
         self.to_v = LinearNoBias(dim, dim_kv_inner)
         self.to_out = LinearNoBias(dim_q_inner, dim)
 
+        # alphafold gating per head, for attending to nothing
+
+        self.to_gates = None
+
+        if gate_values:
+            self.to_gates = Sequential(
+                LinearNoBias(dim, query_heads),
+                Rearrange('b n h -> b h n 1'),
+                nn.Sigmoid()
+            )
+
         # stability related
 
-        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
-        self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
+        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
+        self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
+
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
 
     def muon_parameters(self):
+        # omit the queries and keys for now given what we learned from kimi 2 paper
+
         return [
             *self.to_v.parameters(),
             *self.to_out.parameters(),
@@ -994,8 +1351,9 @@ class Attention(Module):
         self,
         tokens,          # (b n d)
         kv_cache = None,
-
+        return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1008,11 +1366,28 @@ class Attention(Module):
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1020,18 +1395,18 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
 
         out = attend_fn(q, k, v)
 
+        # gate values
+
+        if exists(self.to_gates):
+            gates = self.to_gates(tokens)
+            out = out * gates
+
         # merge heads
 
         out = self.merge_heads(out)
@@ -1042,10 +1417,10 @@ class Attention(Module):
 
         out = inverse_packed_batch(out)
 
-        if not
+        if not return_intermediates:
             return out
 
-        return out, stack((k, v))
+        return out, AttentionIntermediates(stack((k, v)), tokens)
 
 # feedforward
 
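The attention changes add learned value-residual mixing and per-head output gates. A shape-level sketch with random stand-ins for the learned sigmoid outputs (not the module's own layers):

    import torch

    b, h, n, d = 1, 2, 4, 8
    v_first = torch.randn(b, h, n, d)     # values produced by the first layer
    v = torch.randn(b, h, n, d)           # values in the current layer

    mix = torch.rand(b, h, n, 1)          # stand-in for the learned per-token, per-head mix in [0, 1]
    v = v.lerp(v_first, mix)              # v + mix * (v_first - v)

    gates = torch.rand(b, h, n, 1)        # stand-in for the per-head output gate
    attn_out = torch.randn(b, h, n, d) * gates  # lets a head output ~zero, i.e. attend to nothing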
@@ -1085,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1093,9 +1469,12 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
+        rnn_time = False
     ):
         super().__init__()
+        assert depth >= time_block_every, f'depth must be at least {time_block_every}'
 
         # hyper connections
 
@@ -1113,6 +1492,24 @@ class AxialSpaceTimeTransformer(Module):
 
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
+        # a gru layer across time
+
+        self.rnn_time = rnn_time
+        rnn_layers = []
+
         # transformer
 
         layers = []
@@ -1124,17 +1521,24 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s
-            rearrange_from_attend = Rearrange('b s t
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
+            rnn_layers.append(ModuleList([
+                nn.RMSNorm(dim),
+                nn.GRU(dim, dim, batch_first = True)
+            ]) if is_time_block and rnn_time else None)
+
         self.layers = ModuleList(layers)
+        self.rnn_layers = ModuleList(rnn_layers)
+
         self.is_time = is_time
 
         # final norm
@@ -1145,17 +1549,31 @@ class AxialSpaceTimeTransformer(Module):
 
         self.num_special_spatial_tokens = num_special_spatial_tokens
 
+    def muon_parameters(self):
+        muon_params = []
+
+        for m in self.modules():
+            if isinstance(m, (Attention, SwiGLUFeedforward)):
+                muon_params.extend(m.muon_parameters())
+
+        return muon_params
+
     def forward(
         self,
-        tokens
-
+        tokens,                         # (b t s d)
+        kv_cache: Tensor | None = None, # (y 2 b h t d)
+        return_intermediates = False
+
+    ): # (b t s d) | (y 2 b h t d)
+
         batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
 
         assert tokens.ndim == 4
 
         # attend functions for space and time
 
-
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1163,37 +1581,120 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
 
+        # prepare cache
+
+        time_attn_kv_caches = []
+
+        if has_kv_cache:
+            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+
+            rotary_seq_len = 1
+            rotary_pos_offset = past_tokens.shape[1]
+        else:
+            rotary_seq_len = time
+            rotary_pos_offset = 0
+
+        kv_cache = default(kv_cache, (None,))
+
+        iter_kv_cache = iter(kv_cache)
+
         # rotary
 
-        rotary_pos_emb = self.time_rotary(
+        rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
+
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
+        # normed attention inputs
+
+        normed_time_attn_inputs = []
+        normed_space_attn_inputs = []
 
         # attention
 
         tokens = self.expand_streams(tokens)
 
-        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
+        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
 
             tokens = pre_attn_rearrange(tokens)
 
+            # maybe rnn for time
+
+            if layer_is_time and exists(maybe_rnn_modules):
+                rnn_prenorm, rnn = maybe_rnn_modules
+
+                rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
+
+                rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
+
+                rnn_out = inverse_pack_time(rnn_out)
+
+                tokens = rnn_out + tokens
+
             # when is a axial time attention block, should be causal
 
             attend_fn = time_attend if layer_is_time else space_attend
 
             layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
 
+            # maybe past kv cache
+
+            maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
+
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
-            tokens = attn(
+            tokens, attn_intermediates = attn(
+                tokens,
+                rotary_pos_emb = layer_rotary_pos_emb,
+                attend_fn = attend_fn,
+                kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
+                return_intermediates = True
+            )
 
             tokens = post_attn_rearrange(tokens)
 
             # feedforward layer
 
-            tokens = ff(tokens)
+            tokens = ff(tokens)
+
+            # save kv cache if is time layer
+
+            if layer_is_time:
+                time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+            # save time attention inputs for decorr
+
+            space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+            space_or_time_inputs.append(attn_intermediates.normed_inputs)
 
         tokens = self.reduce_streams(tokens)
 
-
+        out = self.final_norm(tokens)
+
+        if has_kv_cache:
+            # just concat the past tokens back on for now, todo - clean up the logic
+            out = cat((past_tokens, out), dim = 1)
+
+        if not return_intermediates:
+            return out
+
+        intermediates = TransformerIntermediates(
+            stack(time_attn_kv_caches),
+            safe_stack(normed_time_attn_inputs),
+            safe_stack(normed_space_attn_inputs)
+        )
+
+        return out, intermediates
 
 # video tokenizer
 
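The time-axis KV caching concatenates cached keys and values with those of the newest timestep only, then stacks them as the cache for the next step. A toy illustration with arbitrary shapes:

    import torch

    # toy shapes: (batch, heads, time, dim_head)
    cached_k = torch.randn(1, 2, 6, 8)
    cached_v = torch.randn(1, 2, 6, 8)

    new_k = torch.randn(1, 2, 1, 8)   # keys for the single newest timestep
    new_v = torch.randn(1, 2, 1, 8)

    k = torch.cat((cached_k, new_k), dim = -2)
    v = torch.cat((cached_v, new_v), dim = -2)

    next_cache = torch.stack((k, v))  # (2, b, h, t, d), carried into the next decoding step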
@@ -1219,12 +1720,15 @@ class VideoTokenizer(Module):
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
         lpips_loss_network: Module | None = None,
         lpips_loss_weight = 0.2,
+        encoder_add_decor_aux_loss = False,
+        decor_auxx_loss_weight = 0.1,
+        decorr_sample_frac = 0.25,
         nd_rotary_kwargs: dict = dict(
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
         ),
-        num_residual_streams = 1
+        num_residual_streams = 1,
     ):
         super().__init__()
 
@@ -1305,6 +1809,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -1318,10 +1823,24 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)
 
+        # decorr aux loss
+        # https://arxiv.org/abs/2510.14657
+
+        self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+        self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+        self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
     @property
     def device(self):
         return self.zero.device
 
+    def muon_parameters(self):
+        return [
+            *self.encoder_transformer.muon_parameters(),
+            *self.decoder_transformer.muon_parameters()
+        ]
+
     @torch.no_grad()
     def tokenize(
         self,
@@ -1425,7 +1944,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens = self.encoder_transformer(tokens)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -1447,19 +1966,30 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             lpips_loss = self.lpips(video, recon_video)
 
+        time_decorr_loss = space_decorr_loss = self.zero
+
+        if self.encoder_add_decor_aux_loss:
+            if exists(time_attn_normed_inputs):
+                time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+            if exists(space_attn_normed_inputs):
+                space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
         # losses
 
         total_loss = (
             recon_loss +
-            lpips_loss * self.lpips_loss_weight
+            lpips_loss * self.lpips_loss_weight +
+            time_decorr_loss * self.decorr_aux_loss_weight +
+            space_decorr_loss * self.decorr_aux_loss_weight
        )
 
         if not return_all_losses:
             return total_loss
 
-        losses = (recon_loss, lpips_loss)
+        losses = (recon_loss, lpips_loss, decorr_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
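The decorrelation auxiliary loss is imported from vit_pytorch (DecorrelationLoss); the underlying idea is to penalize off-diagonal covariance between feature channels of the normed attention inputs. A rough standalone sketch of that kind of penalty, not the imported implementation:

    import torch

    feats = torch.randn(32, 16)              # (tokens, feature dim), assumed already normalized
    feats = feats - feats.mean(dim = 0)

    cov = (feats.T @ feats) / feats.shape[0] # (d, d) covariance estimate
    off_diag = cov - torch.diag(cov.diagonal())

    decorr_loss = off_diag.pow(2).mean()     # penalize correlated feature channels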
@@ -1475,13 +2005,15 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        num_video_views = 1,
+        dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
-        attn_kwargs: dict = dict(
-
-
+        attn_kwargs: dict = dict(),
+        transformer_kwargs: dict = dict(),
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
@@ -1493,15 +2025,25 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-
+        multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-
+        latent_flow_loss_weight = 1.,
+        reward_loss_weight: float | list[float] = 1.,
+        discrete_action_loss_weight: float | list[float] = 1.,
+        continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_reverse_kl = True,
+        pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -1535,7 +2077,7 @@ class DynamicsWorldModel(Module):
         )
 
         self.to_latent_pred = Sequential(
-            Reduce('b t n s d -> b t n d', 'mean'),
+            Reduce('b t v n s d -> b t v n d', 'mean'),
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent)
         )
@@ -1545,15 +2087,38 @@ class DynamicsWorldModel(Module):
         latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
 
         self.latents_to_spatial_tokens = Sequential(
-            Rearrange('
+            Rearrange('... n d -> ... (n d)'),
             Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-            Rearrange('
+            Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
         )
 
         self.to_latent_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-            Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
+        )
+
+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
+        # proprioception
+
+        self.has_proprio = exists(dim_proprio)
+        self.dim_proprio = dim_proprio
+
+        if self.has_proprio:
+            self.to_proprio_token = nn.Linear(dim_proprio, dim)
+
+            self.to_proprio_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, dim_proprio)
         )
 
         # register tokens
@@ -1597,6 +2162,7 @@ class DynamicsWorldModel(Module):
         # learned set of latent genes
 
         self.agent_has_genes = num_latent_genes > 0
+        self.num_latent_genes = num_latent_genes
         self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
 
         # policy head
@@ -1616,10 +2182,14 @@ class DynamicsWorldModel(Module):
             num_continuous_actions = num_continuous_actions,
             continuous_norm_stats = continuous_norm_stats,
             can_unembed = True,
-            unembed_dim = dim * 4
+            unembed_dim = dim * 4,
+            num_unembed_preds = multi_token_pred_len,
+            squeeze_unembed_preds = False
         )
 
-
+        # multi token prediction length
+
+        self.multi_token_pred_len = multi_token_pred_len
 
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
 
@@ -1632,12 +2202,15 @@ class DynamicsWorldModel(Module):
             learned_embedding = add_reward_embed_to_agent_token
         )
 
-
+        to_reward_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, self.reward_encoder.num_bins)
         )
 
-        self.
+        self.to_reward_pred = Ensemble(
+            to_reward_pred,
+            multi_token_pred_len
+        )
 
         # value head
 
@@ -1653,13 +2226,16 @@ class DynamicsWorldModel(Module):
         self.transformer = AxialSpaceTimeTransformer(
             dim = dim,
             depth = depth,
+            attn_heads = attn_heads,
             attn_dim_head = attn_dim_head,
             attn_softclamp_value = attn_softclamp_value,
             attn_kwargs = attn_kwargs,
             ff_kwargs = ff_kwargs,
             num_residual_streams = num_residual_streams,
             num_special_spatial_tokens = num_agents,
-
+            time_block_every = time_block_every,
+            final_norm = False,
+            **transformer_kwargs
         )
 
         # ppo related
@@ -1670,9 +2246,40 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight
+
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+        self.pmpo_reverse_kl = pmpo_reverse_kl
+
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
+        # loss related
+
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
+
+        self.latent_flow_loss_weight = latent_flow_loss_weight
+
+        self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
+        self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
+        self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
 
-
+        assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
@@ -1680,17 +2287,19 @@ class DynamicsWorldModel(Module):
     def device(self):
         return self.zero.device
 
-
-        self,
-        signal_levels,
-        align_dims_left_to = None
-    ):
-        times = signal_levels.float() / self.max_steps
+    # types of parameters
 
-
-
+    def muon_parameters(self):
+        return self.transformer.muon_parameters()
 
-
+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]
+
+    def value_head_parameters(self):
+        return self.value_head.parameters()
 
     def parameter(self):
         params = super().parameters()
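The new muon_parameters / policy_head_parameters / value_head_parameters accessors exist so different parameter groups can be fed to different optimizers or learning rates. A generic sketch of wiring such groups into an optimizer, using a hypothetical two-layer model rather than the world model itself:

    import torch
    from torch import nn
    from torch.optim import AdamW

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

    # hypothetical split: give each layer its own learning rate, in the same spirit as
    # handing the muon / policy-head / value-head groups to separate optimizers
    optim = AdamW([
        dict(params = model[0].parameters(), lr = 3e-4),
        dict(params = model[1].parameters(), lr = 1e-4),
    ])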
@@ -1700,53 +2309,402 @@ class DynamicsWorldModel(Module):
|
|
|
1700
2309
|
|
|
1701
2310
|
return list(set(params) - set(self.video_tokenizer.parameters()))
|
|
1702
2311
|
|
|
1703
|
-
|
|
2312
|
+
# helpers for shortcut flow matching
|
|
2313
|
+
|
|
2314
|
+
def get_times_from_signal_level(
|
|
2315
|
+
self,
|
|
2316
|
+
signal_levels,
|
|
2317
|
+
align_dims_left_to = None
|
|
2318
|
+
):
|
|
2319
|
+
times = signal_levels.float() / self.max_steps
|
|
2320
|
+
|
|
2321
|
+
if not exists(align_dims_left_to):
|
|
2322
|
+
return times
|
|
2323
|
+
|
|
2324
|
+
return align_dims_left(times, align_dims_left_to)
|
|
2325
|
+
|
|
2326
|
+
# evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
|
|
2327
|
+
|
|
2328
|
+
@torch.no_grad()
|
|
2329
|
+
def evolve_(
|
|
2330
|
+
self,
|
|
2331
|
+
fitness,
|
|
2332
|
+
select_frac = 0.5,
|
|
2333
|
+
tournament_frac = 0.5
|
|
2334
|
+
):
|
|
2335
|
+
assert fitness.numel() == self.num_latent_genes
|
|
2336
|
+
|
|
2337
|
+
pop = self.latent_genes
|
|
2338
|
+
|
|
2339
|
+
pop_size = self.num_latent_genes
|
|
2340
|
+
num_selected = ceil(pop_size * select_frac)
|
|
2341
|
+
num_children = pop_size - num_selected
|
|
2342
|
+
|
|
2343
|
+
dim_gene = pop.shape[-1]
|
|
2344
|
+
|
|
2345
|
+
# natural selection just a sort and slice
|
|
2346
|
+
|
|
2347
|
+
selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
|
|
2348
|
+
selected = pop[selected_indices]
|
|
2349
|
+
|
|
2350
|
+
# use tournament - one tournament per child
|
|
2351
|
+
|
|
2352
|
+
tournament_size = max(2, ceil(num_selected * tournament_frac))
|
|
2353
|
+
|
|
2354
|
+
tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
|
|
2355
|
+
|
|
2356
|
+
parent_ids = selected_fitness[tournaments].topk(2, dim = -1).indices # get top 2 winners as parents
|
|
2357
|
+
|
|
2358
|
+
parents = selected[parent_ids]
|
|
2359
|
+
|
|
2360
|
+
# crossover by random interpolation from parent1 to parent2
|
|
2361
|
+
|
|
2362
|
+
random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
|
|
2363
|
+
|
|
2364
|
+
parent1, parent2 = parents.unbind(dim = 1)
|
|
2365
|
+
children = parent1.lerp(parent2, random_uniform_mix)
|
|
2366
|
+
|
|
2367
|
+
# store next population
|
|
2368
|
+
|
|
2369
|
+
next_pop = cat((selected, children))
|
|
2370
|
+
|
|
2371
|
+
self.latent_genes.copy_(next_pop)
|
|
2372
|
+
|
+    # interacting with env for experience
+
+    @torch.no_grad()
+    def interact_with_env(
+        self,
+        env,
+        seed = None,
+        agent_index = 0,
+        step_size = 4,
+        max_timesteps = 16,
+        env_is_vectorized = False,
+        use_time_kv_cache = True,
+        store_agent_embed = True,
+        store_old_action_unembeds = True,
+    ):
+        assert exists(self.video_tokenizer)
+
+        init_frame = env.reset()
+
+        # frame to video
+
+        if env_is_vectorized:
+            video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
+        else:
+            video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+
+        batch, device = video.shape[0], video.device
+
+        # accumulate
+
+        rewards = None
+        discrete_actions = None
+        continuous_actions = None
+        discrete_log_probs = None
+        continuous_log_probs = None
+        values = None
+        latents = None
+
+        acc_agent_embed = None
+        acc_policy_embed = None
+
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        is_truncated = full((batch,), False, device = device)
+
+        episode_lens = full((batch,), 0, device = device)
+
+        # maybe time kv cache
+
+        time_kv_cache = None
+
+        step_index = 0
+
+        while not is_terminated.all():
+            step_index += 1
+
+            latents = self.video_tokenizer(video, return_latents = True)
+
+            _, (agent_embed, next_time_kv_cache) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                time_kv_cache = time_kv_cache,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
+
+            # time kv cache
+
+            if use_time_kv_cache:
+                time_kv_cache = next_time_kv_cache
+
+            # get one agent
+
+            one_agent_embed = agent_embed[..., -1:, agent_index, :]
+
+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
+            policy_embed = self.policy_head(one_agent_embed)
+
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
+            # sample actions
+
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
+
+            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
+            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
+
+            # get the log prob and values for policy optimization
+
+            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                pred_head_index = 0,
+                discrete_targets = sampled_discrete_actions,
+                continuous_targets = sampled_continuous_actions,
+            )
+
+            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
+            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
+
+            # pass the sampled action to the environment and get back next state and reward
+
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminated = full((batch,), False)
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminated = env_step_out
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 4:
+                next_frame, reward, terminated, truncated = env_step_out
+
+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            # (1) - environment says it is terminated
+            # (2) - previous step is truncated (this step is for bootstrap value)
+
+            is_terminated |= (terminated | is_truncated)
+
+            # update `is_truncated`
+
+            if step_index <= max_timesteps:
+                is_truncated |= truncated
+
+            if step_index == max_timesteps:
+                # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+                is_truncated |= ~is_terminated
+
+            # batch and time dimension
+
+            if env_is_vectorized:
+                next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
+                reward = rearrange(reward, 'b -> b 1')
+            else:
+                next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+                reward = rearrange(reward, ' -> 1 1')
+
+            # concat
+
+            video = cat((video, next_frame), dim = 2)
+            rewards = safe_cat((rewards, reward), dim = 1)
+
+            acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
+
+        # package up one experience for learning
+
+        batch, device = latents.shape[0], latents.device
+
+        one_experience = Experience(
+            latents = latents,
+            video = video[:, :, :-1],
+            rewards = rewards,
+            actions = (discrete_actions, continuous_actions),
+            log_probs = (discrete_log_probs, continuous_log_probs),
+            values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
+            step_size = step_size,
+            agent_index = agent_index,
+            is_truncated = is_truncated,
+            lens = episode_lens,
+            is_from_world_model = False
+        )
+
+        return one_experience
+
|
+
# ppo
|
|
2566
|
+
|
|
2567
|
+
def learn_from_experience(
|
|
1704
2568
|
self,
|
|
1705
|
-
|
|
2569
|
+
experience: Experience,
|
|
1706
2570
|
policy_optim: Optimizer | None = None,
|
|
1707
|
-
value_optim: Optimizer | None = None
|
|
2571
|
+
value_optim: Optimizer | None = None,
|
|
2572
|
+
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
|
2573
|
+
use_pmpo = True,
|
|
2574
|
+
normalize_advantages = None,
|
|
2575
|
+
eps = 1e-6
|
|
1708
2576
|
):
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
2577
|
+
assert isinstance(experience, Experience)
|
|
2578
|
+
|
|
2579
|
+
experience = experience.to(self.device)
|
|
2580
|
+
|
|
2581
|
+
latents = experience.latents
|
|
2582
|
+
actions = experience.actions
|
|
2583
|
+
old_log_probs = experience.log_probs
|
|
2584
|
+
old_values = experience.values
|
|
2585
|
+
rewards = experience.rewards
|
|
2586
|
+
agent_embeds = experience.agent_embed
|
|
2587
|
+
old_action_unembeds = experience.old_action_unembeds
|
|
2588
|
+
|
|
2589
|
+
step_size = experience.step_size
|
|
2590
|
+
agent_index = experience.agent_index
|
|
2591
|
+
|
|
2592
|
+
assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
|
|
2593
|
+
|
|
2594
|
+
batch, time = latents.shape[0], latents.shape[1]
|
|
1714
2595
|
|
|
1715
|
-
|
|
1716
|
-
agent_index = generation.agent_index
|
|
2596
|
+
# calculate returns
|
|
1717
2597
|
|
|
1718
|
-
|
|
2598
|
+
# mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
|
|
2599
|
+
|
|
2600
|
+
if not exists(experience.is_truncated):
|
|
2601
|
+
experience.is_truncated = full((batch,), True, device = latents.device)
|
|
2602
|
+
|
|
2603
|
+
if exists(experience.lens):
|
|
2604
|
+
mask_for_gae = lens_to_mask(experience.lens, time)
|
|
2605
|
+
|
|
2606
|
+
rewards = rewards.masked_fill(~mask_for_gae, 0.)
|
|
2607
|
+
old_values = old_values.masked_fill(~mask_for_gae, 0.)
|
|
2608
|
+
|
|
2609
|
+
# calculate returns
|
|
1719
2610
|
|
|
1720
2611
|
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
|
|
1721
2612
|
|
|
1722
|
-
#
|
|
2613
|
+
# handle variable lengths
|
|
2614
|
+
|
|
2615
|
+
max_time = latents.shape[1]
|
|
2616
|
+
is_var_len = exists(experience.lens)
|
|
2617
|
+
|
|
2618
|
+
mask = None
|
|
2619
|
+
|
|
2620
|
+
if is_var_len:
|
|
2621
|
+
learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
|
|
2622
|
+
mask = lens_to_mask(learnable_lens, max_time)
|
|
2623
|
+
|
|
2624
|
+
# determine whether to finetune entire transformer or just learn the heads
|
|
2625
|
+
|
|
2626
|
+
world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
|
|
2627
|
+
|
|
2628
|
+
# maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
|
|
2629
|
+
|
|
2630
|
+
if self.keep_reward_ema_stats:
|
|
2631
|
+
ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
|
|
2632
|
+
|
|
2633
|
+
decay = 1. - self.reward_ema_decay
|
|
2634
|
+
|
|
2635
|
+
# quantile filter
|
|
2636
|
+
|
|
2637
|
+
lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
|
|
2638
|
+
returns_for_stats = returns.clamp(lo, hi)
|
|
2639
|
+
|
|
2640
|
+
# mean, var - todo - handle distributed
|
|
2641
|
+
|
|
2642
|
+
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
|
|
2643
|
+
|
|
2644
|
+
# ema
|
|
2645
|
+
|
|
2646
|
+
ema_returns_mean.lerp_(returns_mean, decay)
|
|
2647
|
+
ema_returns_var.lerp_(returns_var, decay)
|
|
2648
|
+
|
|
2649
|
+
# normalize
|
|
2650
|
+
|
|
2651
|
+
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
|
|
2652
|
+
|
|
2653
|
+
normed_returns = (returns - ema_returns_mean) / ema_returns_std
|
|
2654
|
+
normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
|
|
2655
|
+
|
|
2656
|
+
advantage = normed_returns - normed_old_values
|
|
2657
|
+
else:
|
|
2658
|
+
advantage = returns - old_values
|
|
2659
|
+
|
|
2660
|
+
# if using pmpo, do not normalize advantages, but can be overridden
|
|
2661
|
+
|
|
2662
|
+
normalize_advantages = default(normalize_advantages, not use_pmpo)
|
|
2663
|
+
|
|
2664
|
+
if normalize_advantages:
|
|
2665
|
+
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
|
|
2666
|
+
|
|
1723
2667
|
# https://arxiv.org/abs/2410.04166v1
|
|
1724
2668
|
|
|
1725
|
-
|
|
2669
|
+
if use_pmpo:
|
|
2670
|
+
pos_advantage_mask = advantage >= 0.
|
|
2671
|
+
neg_advantage_mask = ~pos_advantage_mask
|
|
1726
2672
|
|
|
1727
2673
|
# replay for the action logits and values
|
|
2674
|
+
# but only do so if fine tuning the entire world model for RL
|
|
1728
2675
|
|
|
1729
2676
|
discrete_actions, continuous_actions = actions
|
|
1730
2677
|
|
|
1731
|
-
|
|
1732
|
-
|
|
1733
|
-
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
2678
|
+
if (
|
|
2679
|
+
not only_learn_policy_value_heads or
|
|
2680
|
+
not exists(agent_embeds)
|
|
2681
|
+
):
|
|
2682
|
+
|
|
2683
|
+
with world_model_forward_context():
|
|
2684
|
+
_, (agent_embeds, _) = self.forward(
|
|
2685
|
+
latents = latents,
|
|
2686
|
+
signal_levels = self.max_steps - 1,
|
|
2687
|
+
step_sizes = step_size,
|
|
2688
|
+
rewards = rewards,
|
|
2689
|
+
discrete_actions = discrete_actions,
|
|
2690
|
+
continuous_actions = continuous_actions,
|
|
2691
|
+
latent_is_noised = True,
|
|
2692
|
+
return_pred_only = True,
|
|
2693
|
+
return_intermediates = True
|
|
2694
|
+
)
|
|
2695
|
+
|
|
2696
|
+
agent_embeds = agent_embeds[..., agent_index, :]
|
|
1742
2697
|
|
|
1743
|
-
|
|
2698
|
+
# maybe detach agent embed
|
|
2699
|
+
|
|
2700
|
+
if only_learn_policy_value_heads:
|
|
2701
|
+
agent_embeds = agent_embeds.detach()
|
|
1744
2702
|
|
|
1745
2703
|
# ppo
|
|
1746
2704
|
|
|
1747
|
-
policy_embed = self.policy_head(
|
|
2705
|
+
policy_embed = self.policy_head(agent_embeds)
|
|
1748
2706
|
|
|
1749
|
-
log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
|
|
2707
|
+
log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
|
|
1750
2708
|
|
|
1751
2709
|
# concat discrete and continuous actions into one for optimizing
|
|
1752
2710
|
|
|
@@ -1754,29 +2712,77 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)

-
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions

-
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1

-
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask
+
+            α = self.pmpo_pos_to_neg_weight
+
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)
+
+            policy_loss = -(α * pos + (1. - α) * neg)
+
+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
+
+                # mentioned that the "reverse direction for the prior KL" was used
+                # make optional, as observed instability in toy task
+
+                if self.pmpo_reverse_kl:
+                    kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)

-
+                # accumulate discrete and continuous kl div

-
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+                kl_div_loss = 0.

-
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+            policy_loss = masked_mean(policy_loss, mask)

         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+
+        entropy_loss = masked_mean(entropy_loss, mask)
+
+        # total policy loss

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

-        #
+        # maybe take policy optimizer step

         if exists(policy_optim):
             total_policy_loss.backward()
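The PMPO branch above weights positive- and negative-advantage samples by the sign of the advantage only. A self-contained sketch of that weighting (the `masked_mean` helper here is a local stand-in for the repo's, so its exact behavior is an assumption):

import torch

def masked_mean(t, mask):
    # local stand-in for the repo's masked_mean helper
    if mask is None:
        return t.mean()
    return t[mask].mean()

log_probs = torch.randn(2, 8)   # (batch, time) log-probs of the taken actions
advantage = torch.randn(2, 8)   # GAE advantages
alpha = 0.5                     # analogous to pmpo_pos_to_neg_weight

pos_mask = advantage >= 0.
neg_mask = ~pos_mask

# only the sign of the advantage matters - eq (10) of https://arxiv.org/abs/2410.04166
policy_loss = -(alpha * masked_mean(log_probs, pos_mask) + (1. - alpha) * -masked_mean(log_probs, neg_mask))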
@@ -1786,7 +2792,7 @@ class DynamicsWorldModel(Module):

         # value loss

-        value_bins = self.value_head(
+        value_bins = self.value_head(agent_embeds)
         values = self.reward_encoder.bins_to_scalar_value(value_bins)

         clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
@@ -1794,10 +2800,19 @@ class DynamicsWorldModel(Module):

         return_bins = self.reward_encoder(returns)

+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2)
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step

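A hedged sketch of one policy-optimization step with the method above, assuming `world_model` and a collected `experience` already exist; the optimizer choice and parameter grouping are illustrative, not prescribed by the package.

from torch.optim import Adam

# `experience` may come from world_model.interact_with_env(...) or from
# world_model.generate(..., return_for_policy_optimization = True)

policy_optim = Adam(world_model.policy_head.parameters(), lr = 3e-4)
value_optim = Adam(world_model.value_head.parameters(), lr = 3e-4)

world_model.learn_from_experience(
    experience,
    policy_optim = policy_optim,
    value_optim = value_optim,
    only_learn_policy_value_heads = True,  # as in the paper, only the heads are trained
    use_pmpo = True
)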
@@ -1816,22 +2831,49 @@ class DynamicsWorldModel(Module):
         num_steps = 4,
         batch_size = 1,
         agent_index = 0,
+        tasks: int | Tensor | None = None,
+        latent_gene_ids = None,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        time_kv_cache: Tensor | None = None,
+        use_time_kv_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
-        return_log_probs_and_values = False
+        return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
+        return_time_kv_cache = False,
+        store_agent_embed = True,
+        store_old_action_unembeds = True

     ): # (b t n d) | (b c t h w)

+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
+        has_proprio = self.has_proprio
         was_training = self.training
         self.eval()

+        # validation
+
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

+        if isinstance(tasks, int):
+            tasks = full((batch_size,), tasks, device = self.device)
+
+        assert not exists(tasks) or tasks.shape[0] == batch_size
+
+        # get state latent shape
+
         latent_shape = self.latent_shape

         # derive step size
@@ -1841,9 +2883,16 @@ class DynamicsWorldModel(Module):
         # denoising
         # teacher forcing to start with

-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
+
+        past_latents_context_noise = latents.clone()

-
+        # maybe internal state
+
+        if has_proprio:
+            proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
+
+            past_proprio_context_noise = proprio.clone()

         # maybe return actions

@@ -1858,8 +2907,17 @@ class DynamicsWorldModel(Module):
         decoded_continuous_log_probs = None
         decoded_values = None

+        # maybe store agent embed
+
+        acc_agent_embed = None
+
+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards

+        decoded_rewards = None
         if return_rewards_per_frame:
             decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)

@@ -1868,55 +2926,131 @@ class DynamicsWorldModel(Module):
         while latents.shape[1] < time_steps:

             curr_time_steps = latents.shape[1]
-            noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)

-
+            # determine whether to take an extra step if
+            # (1) using time kv cache
+            # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+            take_extra_step = (
+                use_time_kv_cache or
+                return_rewards_per_frame or
+                store_agent_embed or
+                return_agent_actions
+            )
+
+            # prepare noised latent / proprio inputs
+
+            noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
+
+            noised_proprio = None
+
+            if has_proprio:
+                noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
+
+            # denoising steps
+
+            for step in range(num_steps + int(take_extra_step)):
+
+                is_last_step = (step + 1) == num_steps
+
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
+
+                # noising past latent context

-                noised_context = latents.lerp(
+                noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)

-                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
+
+                # handle proprio
+
+                noised_proprio_with_context = None
+
+                if has_proprio:
+                    noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
+                    noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
+
+                # proper signal levels

                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-                pred, agent_embed = self.forward(
+                pred, (agent_embed, next_time_kv_cache) = self.forward(
                     latents = noised_latent_with_context,
                     signal_levels = signal_levels_with_context,
                     step_sizes = step_size,
                     rewards = decoded_rewards,
+                    tasks = tasks,
+                    latent_gene_ids = latent_gene_ids,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    proprio = noised_proprio_with_context,
+                    time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
+                    latent_has_view_dim = True,
                     return_pred_only = True,
-
+                    return_intermediates = True,
                 )

-
+                if use_time_kv_cache and is_last_step:
+                    time_kv_cache = next_time_kv_cache
+
+                # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+                if take_extra_step and is_last_step:
+                    break
+
+                # maybe proprio
+
+                if has_proprio:
+                    pred, pred_proprio = pred
+
+                # unpack pred
+
+                _, pred = unpack(pred, pack_context_shape, 'b * v n d')
+
+                if has_proprio:
+                    _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')

                 # derive flow, based on whether in x-space or not

-
-
-
-
-
+                def denoise_step(pred, noised, signal_levels):
+                    if self.pred_orig_latent:
+                        times = self.get_times_from_signal_level(signal_levels)
+                        aligned_times = align_dims_left(times, noised)
+
+                        flow = (pred - noised) / (1. - aligned_times)
+                    else:
+                        flow = pred
+
+                    return flow * (step_size / self.max_steps)

                 # denoise

-                noised_latent +=
+                noised_latent += denoise_step(pred, noised_latent, signal_levels)
+
+                if has_proprio:
+                    noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)

             denoised_latent = noised_latent # it is now denoised

+            if has_proprio:
+                denoised_proprio = noised_proprio
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step

             if return_rewards_per_frame:
                 one_agent_embed = agent_embed[:, -1:, agent_index]

-                reward_logits = self.to_reward_pred(one_agent_embed)
+                reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
                 pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)

                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+            # maybe store agent embed
+
+            if store_agent_embed:
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+                acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
+
             # decode the agent actions if needed

             if return_agent_actions:
@@ -1926,7 +3060,14 @@ class DynamicsWorldModel(Module):

                 policy_embed = self.policy_head(one_agent_embed)

-
+                # maybe store old actions
+
+                if store_old_action_unembeds:
+                    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
+                # sample actions
+
+                sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

                 decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
                 decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
@@ -1934,6 +3075,7 @@ class DynamicsWorldModel(Module):
                 if return_log_probs_and_values:
                     discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                         policy_embed,
+                        pred_head_index = 0,
                         discrete_targets = sampled_discrete_actions,
                         continuous_targets = sampled_continuous_actions,
                     )
@@ -1952,7 +3094,14 @@ class DynamicsWorldModel(Module):

             # add new fixed context noise for the temporal consistency

-
+            past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
+
+            # handle proprio
+
+            if has_proprio:
+                proprio = cat((proprio, denoised_proprio), dim = 1)
+
+                past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)

         # restore state

@@ -1966,24 +3115,50 @@ class DynamicsWorldModel(Module):
         video = None

         if return_decoded_video:
+
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
             video = self.video_tokenizer.decode(
-
+                latents_for_video,
                 height = image_height,
                 width = image_width
             )

+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension
+
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
         # only return video or latent if not requesting anything else, for first stage training

-        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
+            out = video if return_decoded_video else latents
+
+            if not return_time_kv_cache:
+                return out
+
+            return out, time_kv_cache

         # returning agent actions, rewards, and log probs + values for policy optimization

+        batch, device = latents.shape[0], latents.device
+        experience_lens = full((batch,), time_steps, device = device)
+
         gen = Experience(
             latents = latents,
             video = video,
+            proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
+            lens = experience_lens,
             is_from_world_model = True
         )

@@ -1998,13 +3173,17 @@ class DynamicsWorldModel(Module):

         gen.values = decoded_values

-
+        if not return_time_kv_cache:
+            return gen
+
+        return gen, time_kv_cache

     def forward(
         self,
         *,
-        video = None, # (b c t vh vw)
-        latents = None, # (b t n d) | (b t d)
+        video = None, # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -2015,25 +3194,46 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        proprio = None, # (b t dp)
+        time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-
-        add_autoregressive_action_loss =
+        return_intermediates = False,
+        add_autoregressive_action_loss = True,
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents

         assert exists(video) ^ exists(latents)

+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'

             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')

-        if latents.ndim ==
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

-        assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

         # variables

@@ -2108,16 +3308,17 @@ class DynamicsWorldModel(Module):

         # times is from 0 to 1

-        times = self.get_times_from_signal_level(signal_levels
+        times = self.get_times_from_signal_level(signal_levels)

         if not latent_is_noised:
             # get the noise

             noise = randn_like(latents)
+            aligned_times = align_dims_left(times, latents)

             # noise from 0 as noise to 1 as data

-            noised_latents = noise.lerp(latents,
+            noised_latents = noise.lerp(latents, aligned_times)

         else:
             noised_latents = latents
@@ -2165,6 +3366,27 @@ class DynamicsWorldModel(Module):

         reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)

+        # maybe proprioception
+
+        assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
+
+        noised_proprio = None
+
+        if self.has_proprio:
+
+            if not latent_is_noised:
+                # get the noise
+
+                proprio_noise = randn_like(proprio)
+                aligned_times = align_dims_left(times, proprio)
+
+                # noise from 0 as noise to 1 as data
+
+                noised_proprio = proprio_noise.lerp(proprio, aligned_times)
+
+            else:
+                noised_proprio = proprio
+
         # maybe create the action tokens

         if exists(discrete_actions) or exists(continuous_actions):
@@ -2185,16 +3407,27 @@ class DynamicsWorldModel(Module):

             action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)

+        elif self.action_embedder.has_actions:
+            action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
+
         else:
             action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens

         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)

+            # maybe add view embedding
+
+            if self.video_has_multi_view:
+                space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
+
+            # merge spatial tokens
+
             space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')

             num_spatial_tokens = space_tokens.shape[-2]
@@ -2212,6 +3445,13 @@ class DynamicsWorldModel(Module):

             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)

+            # maybe proprio
+
+            if exists(noised_proprio):
+                proprio_token = self.to_proprio_token(noised_proprio)
+            else:
+                proprio_token = registers[:, :, 0:0]
+
             # determine signal + step size embed for their diffusion forcing + shortcut

             signal_embed = self.signal_levels_embed(signal_levels)
@@ -2224,15 +3464,15 @@ class DynamicsWorldModel(Module):

             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

             # attention

-            tokens = self.transformer(tokens)
+            tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

             # unpack

-            flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

             # pooling

@@ -2240,10 +3480,22 @@ class DynamicsWorldModel(Module):

             pred = self.to_latent_pred(space_tokens)

+            # maybe proprio
+
+            if self.has_proprio:
+                pred_proprio = self.to_proprio_pred(proprio_token)
+
+                pred = (pred, pred_proprio)
+
+            # returning
+
             if not return_agent_tokens:
                 return pred

-
+            if not return_time_kv_cache:
+                return pred, agent_tokens
+
+            return pred, (agent_tokens, next_time_kv_cache)

         # curry into get_prediction what does not change during first call as well as the shortcut ones

@@ -2251,13 +3503,47 @@ class DynamicsWorldModel(Module):

         # forward the network

-        pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)

         if return_pred_only:
-            if not
+            if not return_intermediates:
+                return pred
+
+            return pred, (encoded_agent_tokens, next_time_kv_cache)
+
+        # pack the predictions to calculate flow for different modalities all at once
+
+        if self.has_proprio:
+            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+
+            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
+            data, _ = pack((latents, proprio), 'b t *')
+            noise, _ = pack((noise, proprio_noise), 'b t *')
+        else:
+            noised = noised_latents
+            data = latents
+
+        # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
+
+        def maybe_pack_unpack(fn):
+            @wraps(fn)
+            @torch.no_grad()
+            def inner(noised, *args, **kwargs):
+
+                noised_proprio = None
+
+                if self.has_proprio:
+                    noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
+
+                pred = fn(noised, noised_proprio, *args, **kwargs)
+
+                if self.has_proprio:
+                    pred, _ = pack(pred, 'b t *')
+
                 return pred
+            return inner

-
+        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)

         # determine the target for the loss

@@ -2274,46 +3560,45 @@ class DynamicsWorldModel(Module):
             # x-space as in paper is in else clause

             if is_v_space_pred:
-                pred_target = flow =
+                pred_target = flow = data - noise
             else:
-                pred_target =
+                pred_target = data
         else:
             # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557

             # basically a consistency loss where you ensure quantity of two half steps equals one step
             # dreamer then makes it works for x-space with some math

-            get_prediction_no_grad = torch.no_grad()(_get_prediction)
-
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one

-            first_step_pred =
+            first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)

             # first derive b'

             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = self.get_times_from_signal_level(signal_levels,
-
+                first_times = self.get_times_from_signal_level(signal_levels, noised)
+
+                first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)

             # take a half step

-            half_step_size_align_left = align_dims_left(half_step_size,
+            half_step_size_align_left = align_dims_left(half_step_size, noised)

-
+            denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)

             # get second prediction for b''

             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-            second_step_pred =
+            second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)

             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step,
-                second_step_pred_flow = (second_step_pred -
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
+                second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)

             # pred target is sg(b' + b'') / 2

@@ -2322,7 +3607,7 @@ class DynamicsWorldModel(Module):
             # need to convert x-space to v-space

             if is_x_space:
-                pred = (pred -
+                pred = (pred - noised) / (1. - first_times)
                 maybe_shortcut_loss_weight = (1. - first_times) ** 2

         # mse loss
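A toy numeric illustration of the shortcut consistency target constructed above: an x-space prediction is converted to a flow, a half step is taken, and the full-step target is the stop-gradient mean of the two half-step flows. The predictions here are random placeholders, not calls into the package.

import torch

max_steps = 64
step_size = 8
t = torch.tensor(0.25)                  # current time in [0, 1)
noised = torch.randn(4)

x_pred_1 = torch.randn(4)               # placeholder first half-step x-space prediction
flow_1 = (x_pred_1 - noised) / (1. - t)

denoised = noised + flow_1 * (step_size / 2 / max_steps)
t_half = t + (step_size / 2) / max_steps

x_pred_2 = torch.randn(4)               # placeholder second half-step prediction
flow_2 = (x_pred_2 - denoised) / (1. - t_half)

pred_target = ((flow_1 + flow_2) / 2).detach()   # sg(b' + b'') / 2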
@@ -2335,9 +3620,23 @@ class DynamicsWorldModel(Module):

         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
+            loss_weight = align_dims_left(loss_weight, flow_losses)
+
             flow_losses = flow_losses * loss_weight

-
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses

@@ -2348,58 +3647,114 @@ class DynamicsWorldModel(Module):
             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

-            reward_pred = self.to_reward_pred(encoded_agent_tokens)
-
+            reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
+
+            reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
+
+            reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)
+
+            reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
+
+            reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
+
+            reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
+
+            if is_var_len:
+                reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+            else:
+                reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss

-
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero

         if (
             self.num_agents == 1 and
             add_autoregressive_action_loss and
+            time > 1,
             (exists(discrete_actions) or exists(continuous_actions))
         ):
             assert self.action_embedder.has_actions

+            # handle actions having time vs time - 1 length
+            # remove the first action if it is equal to time (as it would come from some agent token in the past)
+
+            if exists(discrete_actions) and discrete_actions.shape[1] == time:
+                discrete_actions = discrete_actions[:, 1:]
+
+            if exists(continuous_actions) and continuous_actions.shape[1] == time:
+                continuous_actions = continuous_actions[:, 1:]
+
             # only for 1 agent

             agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
             policy_embed = self.policy_head(agent_tokens[:, :-1])

+            # constitute multi token prediction targets
+
+            discrete_action_targets = continuous_action_targets = None
+
+            if exists(discrete_actions):
+                discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+                discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
+                discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
+
+            if exists(continuous_actions):
+                continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
+                continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
+                continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
+
             discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
                 policy_embed,
-                discrete_targets =
-                continuous_targets =
+                discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
+                continuous_targets = continuous_action_targets if exists(continuous_actions) else None
             )

             if exists(discrete_log_probs):
-
+                discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
+
+                if is_var_len:
+                    discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                    discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

             if exists(continuous_log_probs):
-
+                continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
+
+                if is_var_len:
+                    continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                    continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)

-
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)
+
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)

         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight +
-
+            flow_loss * self.latent_flow_loss_weight +
+            (reward_loss * self.reward_loss_weight).sum() +
+            (discrete_action_loss * self.discrete_action_loss_weight).sum() +
+            (continuous_action_loss * self.continuous_action_loss_weight).sum()
        )

         if not return_all_losses:
             return total_loss

-        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
-
         return total_loss, losses
-
-    # dreamer
-
-class Dreamer(Module):
-    def __init__(
-        self,
-        video_tokenizer: VideoTokenizer,
-        dynamics_model: DynamicsWorldModel,
-    ):
-        super().__init__()