dreamer4 0.0.7__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic.

dreamer4/dreamer4.py CHANGED
@@ -3,19 +3,31 @@ from __future__ import annotations
3
3
  import math
4
4
  from math import ceil, log2
5
5
  from random import random
6
+ from contextlib import nullcontext
6
7
  from collections import namedtuple
7
- from functools import partial
8
+ from functools import partial, wraps
9
+ from dataclasses import dataclass, asdict
8
10
 
9
11
  import torch
10
12
  import torch.nn.functional as F
11
- from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
12
- from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
13
+ from torch.nested import nested_tensor
14
+ from torch.distributions import Normal, kl
15
+ from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
16
+ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
17
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
13
18
 
14
19
  import torchvision
15
20
  from torchvision.models import VGG16_Weights
16
21
 
17
- from x_mlps_pytorch.normed_mlp import create_mlp
22
+ from torch.optim import Optimizer
23
+ from adam_atan2_pytorch import MuonAdamAtan2
24
+
18
25
  from x_mlps_pytorch.ensemble import Ensemble
26
+ from x_mlps_pytorch.normed_mlp import create_mlp
27
+
28
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
29
+
30
+ from vit_pytorch.vit_with_decorr import DecorrelationLoss
19
31
 
20
32
  from assoc_scan import AssocScan
21
33
 
@@ -27,13 +39,18 @@ from assoc_scan import AssocScan
27
39
  # d - feature dimension
28
40
  # f - frequencies (rotary)
29
41
  # l - logit / predicted bins
42
+ # y - layers of transformer
30
43
  # p - positions (3 for spacetime in this work)
31
44
  # t - time
45
+ # na - action dimension (number of discrete and continuous actions)
32
46
  # g - groups of query heads to key heads (gqa)
33
47
  # vc - video channels
34
48
  # vh, vw - video height and width
49
+ # mtp - multi token prediction length
50
+ # v - video viewpoints
35
51
 
36
52
  import einx
53
+ from einx import add, multiply
37
54
  from einops import einsum, rearrange, repeat, reduce, pack, unpack
38
55
  from einops.layers.torch import Rearrange, Reduce
39
56
 
@@ -53,7 +70,95 @@ except ImportError:
53
70
 
54
71
  LinearNoBias = partial(Linear, bias = False)
55
72
 
56
- TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
73
+ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
74
+
75
+ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
76
+
77
+ AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
78
+
79
+ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
80
+
81
+ MaybeTensor = Tensor | None
82
+
83
+ @dataclass
84
+ class Experience:
85
+ latents: Tensor
86
+ video: MaybeTensor = None
87
+ proprio: MaybeTensor = None
88
+ agent_embed: MaybeTensor = None
89
+ rewards: Tensor | None = None
90
+ actions: tuple[MaybeTensor, MaybeTensor] | None = None
91
+ log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
92
+ old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
93
+ values: MaybeTensor = None
94
+ step_size: int | None = None
95
+ lens: MaybeTensor = None
96
+ is_truncated: MaybeTensor = None
97
+ agent_index: int = 0
98
+ is_from_world_model: bool = True
99
+
100
+ def cpu(self):
101
+ return self.to(torch.device('cpu'))
102
+
103
+ def to(self, device):
104
+ experience_dict = asdict(self)
105
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
106
+ return Experience(**experience_dict)
107
+
108
+ def combine_experiences(
109
+ exps: list[Experience]
110
+ ) -> Experience:
111
+
112
+ assert len(exps) > 0
113
+
114
+ # set lens if not there
115
+
116
+ for exp in exps:
117
+ latents = exp.latents
118
+ batch, time, device = *latents.shape[:2], latents.device
119
+
120
+ if not exists(exp.lens):
121
+ exp.lens = full((batch,), time, device = device)
122
+
123
+ if not exists(exp.is_truncated):
124
+ exp.is_truncated = full((batch,), True, device = device)
125
+
126
+ # convert to dictionary
127
+
128
+ exps_dict = [asdict(exp) for exp in exps]
129
+
130
+ values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
131
+
132
+ tree_spec = first(tree_specs)
133
+
134
+ all_field_values = list(zip(*values))
135
+
136
+ # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
137
+
138
+ assert all([
139
+ all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
140
+ for field_values in all_field_values
141
+ ])
142
+
143
+ concatted = []
144
+
145
+ for field_values in all_field_values:
146
+
147
+ if is_tensor(first(field_values)):
148
+
149
+ field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
150
+
151
+ new_field_value = cat(field_values)
152
+ else:
153
+ new_field_value = first(list(set(field_values)))
154
+
155
+ concatted.append(new_field_value)
156
+
157
+ # return experience
158
+
159
+ concat_exp_dict = tree_unflatten(concatted, tree_spec)
160
+
161
+ return Experience(**concat_exp_dict)
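As a quick illustration of the new Experience container (not part of the release; the import path, shapes and values below are assumptions for the sketch), combine_experiences fills in lens, right-pads every tensor field along the time dimension and concatenates along batch:

import torch
from dreamer4.dreamer4 import Experience, combine_experiences  # assumed import path

exp_a = Experience(latents = torch.randn(2, 10, 32))   # batch of 2 rollouts, 10 timesteps
exp_b = Experience(latents = torch.randn(3, 7, 32))    # batch of 3 rollouts, 7 timesteps

merged = combine_experiences([exp_a, exp_b])

# latents are right-padded to the longest time dimension, then concatenated on batch
assert merged.latents.shape == (5, 10, 32)

# lens records the true, unpadded length of every rollout
assert merged.lens.tolist() == [10, 10, 7, 7, 7]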
57
162
 
58
163
  # helpers
59
164
 
@@ -66,6 +171,15 @@ def default(v, d):
66
171
  def first(arr):
67
172
  return arr[0]
68
173
 
174
+ def xnor(x, y):
175
+ return not (x ^ y)
176
+
177
+ def has_at_least_one(*bools):
178
+ return sum([*map(int, bools)]) > 0
179
+
180
+ def ensure_tuple(t):
181
+ return (t,) if not isinstance(t, tuple) else t
182
+
69
183
  def divisible_by(num, den):
70
184
  return (num % den) == 0
71
185
 
@@ -75,8 +189,88 @@ def sample_prob(prob):
75
189
  def is_power_two(num):
76
190
  return log2(num).is_integer()
77
191
 
192
+ def maybe(fn):
193
+ def inner(t, *args, **kwargs):
194
+ if not exists(t) or not exists(fn):
195
+ return None
196
+ return fn(t, *args, **kwargs)
197
+ return inner
198
+
78
199
  # tensor helpers
79
200
 
201
+ def is_empty(t):
202
+ return t.numel() == 0
203
+
204
+ def lens_to_mask(t, max_len = None):
205
+ if not exists(max_len):
206
+ max_len = t.amax().item()
207
+
208
+ device = t.device
209
+ seq = torch.arange(max_len, device = device)
210
+
211
+ return einx.less('j, i -> i j', seq, t)
212
+
213
+ def masked_mean(t, mask = None):
214
+ if not exists(mask):
215
+ return t.mean()
216
+
217
+ if not mask.any():
218
+ return t[mask].sum()
219
+
220
+ return t[mask].mean()
221
+
222
+ def log(t, eps = 1e-20):
223
+ return t.clamp(min = eps).log()
224
+
225
+ def mean_log_var_to_distr(
226
+ mean_log_var: Tensor
227
+ ) -> Normal:
228
+
229
+ mean, log_var = mean_log_var.unbind(dim = -1)
230
+ std = (0.5 * log_var).exp()
231
+ return Normal(mean, std)
232
+
233
+ def safe_stack(tensors, dim = 0):
234
+ tensors = [*filter(exists, tensors)]
235
+
236
+ if len(tensors) == 0:
237
+ return None
238
+
239
+ return stack(tensors, dim = dim)
240
+
241
+ def safe_cat(tensors, dim):
242
+ tensors = [*filter(exists, tensors)]
243
+
244
+ if len(tensors) == 0:
245
+ return None
246
+ elif len(tensors) == 1:
247
+ return tensors[0]
248
+
249
+ return cat(tensors, dim = dim)
250
+
251
+ def safe_squeeze_first(t):
252
+ if not exists(t):
253
+ return None
254
+
255
+ if t.shape[0] != 1:
256
+ return t
257
+
258
+ return rearrange(t, '1 ... -> ...')
259
+
260
+ def gumbel_noise(t):
261
+ noise = torch.rand_like(t)
262
+ return -log(-log(noise))
263
+
264
+ def gumbel_sample(
265
+ t,
266
+ temperature = 1.,
267
+ dim = -1,
268
+ keepdim = False,
269
+ eps = 1e-10
270
+ ):
271
+ noised = (t / max(temperature, eps)) + gumbel_noise(t)
272
+ return noised.argmax(dim = dim, keepdim = keepdim)
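For orientation, gumbel_sample draws an index from the softmax distribution over the logits at the given temperature by perturbing with Gumbel noise and taking the argmax; a tiny sketch with made-up logits:

import torch
from dreamer4.dreamer4 import gumbel_sample  # assumed import path

logits = torch.tensor([[2.0, 0.5, -1.0]])

sampled = gumbel_sample(logits, temperature = 1.)     # index drawn roughly in proportion to softmax(logits)
greedy  = gumbel_sample(logits, temperature = 1e-6)   # temperature -> 0 collapses to the argmax

assert greedy.item() == 0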
273
+
80
274
  def pack_one(t, pattern):
81
275
  packed, packed_shape = pack([t], pattern)
82
276
 
@@ -99,6 +293,27 @@ def pad_at_dim(
99
293
  zeros = ((0, 0) * dims_from_right)
100
294
  return F.pad(t, (*zeros, *pad), value = value)
101
295
 
296
+ def pad_to_len(t, target_len, *, dim):
297
+ curr_len = t.shape[dim]
298
+
299
+ if curr_len >= target_len:
300
+ return t
301
+
302
+ return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
303
+
304
+ def pad_tensors_at_dim_to_max_len(
305
+ tensors: list[Tensor],
306
+ dims: tuple[int, ...]
307
+ ):
308
+ for dim in dims:
309
+ if dim >= first(tensors).ndim:
310
+ continue
311
+
312
+ max_time = max([t.shape[dim] for t in tensors])
313
+ tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
314
+
315
+ return tensors
316
+
102
317
  def align_dims_left(t, aligned_to):
103
318
  shape = t.shape
104
319
  num_right_dims = aligned_to.ndim - t.ndim
@@ -114,8 +329,74 @@ def l2norm(t):
114
329
  def softclamp(t, value = 50.):
115
330
  return (t / value).tanh() * value
116
331
 
332
+ def create_multi_token_prediction_targets(
333
+ t, # (b t ...)
334
+ steps_future,
335
+
336
+ ): # (b t steps ...), (b t steps) - targets and the mask, where mask is False for padding
337
+
338
+ batch, seq_len, device = *t.shape[:2], t.device
339
+
340
+ batch_arange = arange(batch, device = device)
341
+ seq_arange = arange(seq_len, device = device)
342
+ steps_arange = arange(steps_future, device = device)
343
+
344
+ indices = add('t, steps -> t steps', seq_arange, steps_arange)
345
+ mask = indices < seq_len
346
+
347
+ batch_arange = rearrange(batch_arange, 'b -> b 1 1')
348
+
349
+ indices[~mask] = 0
350
+ mask = repeat(mask, 't steps -> b t steps', b = batch)
351
+
352
+ out = t[batch_arange, indices]
353
+
354
+ return out, mask
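A small worked example of the multi token prediction targets (values are illustrative only, and the import path is an assumption): each position is paired with the next steps_future tokens, and the mask is False wherever the window runs past the end of the sequence:

import torch
from dreamer4.dreamer4 import create_multi_token_prediction_targets  # assumed import path

seq = torch.arange(4)[None, :]   # (1, 4) - tokens 0..3

targets, mask = create_multi_token_prediction_targets(seq, steps_future = 2)

# targets[:, i, s] is the token at position i + s, with out-of-range slots pointed back at index 0
assert targets.tolist() == [[[0, 1], [1, 2], [2, 3], [3, 0]]]
assert mask.tolist()    == [[[True, True], [True, True], [True, True], [True, False]]]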
355
+
117
356
  # loss related
118
357
 
358
+ class LossNormalizer(Module):
359
+
360
+ # the authors mentioned the need for loss normalization in the dynamics transformer
361
+
362
+ def __init__(
363
+ self,
364
+ num_losses: int,
365
+ beta = 0.95,
366
+ eps = 1e-6
367
+ ):
368
+ super().__init__()
369
+ self.register_buffer('exp_avg_sq', torch.ones(num_losses))
370
+ self.beta = beta
371
+ self.eps = eps
372
+
373
+ def forward(
374
+ self,
375
+ losses: Tensor | list[Tensor] | dict[str, Tensor],
376
+ update_ema = None
377
+ ):
378
+ exp_avg_sq, beta = self.exp_avg_sq, self.beta
379
+ update_ema = default(update_ema, self.training)
380
+
381
+ # get the rms value - as mentioned at the end of section 3 in the paper
382
+
383
+ rms = exp_avg_sq.sqrt()
384
+
385
+ if update_ema:
386
+ decay = 1. - beta
387
+
388
+ # update the ema
389
+
390
+ exp_avg_sq.lerp_(losses.detach().square(), decay)
391
+
392
+ # then normalize
393
+
394
+ assert losses.numel() == rms.numel()
395
+
396
+ normed_losses = losses / rms.clamp(min = self.eps)
397
+
398
+ return normed_losses
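A minimal usage sketch for the loss normalizer (the two loss values are made up): every loss is divided by an exponential moving RMS of its own magnitude, so losses on very different scales become comparable before they are weighted and summed:

import torch
from dreamer4.dreamer4 import LossNormalizer  # assumed import path

normalizer = LossNormalizer(num_losses = 2)

losses = torch.tensor([0.5, 250.0])   # e.g. a flow loss and a reward loss on very different scales

# first call: the EMA is initialized to 1., so the losses pass through unchanged
normed = normalizer(losses)

# after enough updates the running RMS tracks each loss' scale and the normalized values hover around 1.
for _ in range(200):
    normed = normalizer(torch.tensor([0.5, 250.0]))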
399
+
119
400
  class LPIPSLoss(Module):
120
401
  def __init__(
121
402
  self,
@@ -267,1007 +548,2281 @@ class SymExpTwoHot(Module):
267
548
 
268
549
  return inverse_pack(encoded, '* l')
269
550
 
270
- # generalized advantage estimate
271
-
272
- @torch.no_grad()
273
- def calc_gae(
274
- rewards,
275
- values,
276
- masks,
277
- gamma = 0.99,
278
- lam = 0.95,
279
- use_accelerated = None
280
- ):
281
- assert values.shape[-1] == rewards.shape[-1]
282
- use_accelerated = default(use_accelerated, rewards.is_cuda)
283
-
284
- values = F.pad(values, (0, 1), value = 0.)
285
- values, values_next = values[..., :-1], values[..., 1:]
286
-
287
- delta = rewards + gamma * values_next * masks - values
288
- gates = gamma * lam * masks
289
-
290
- scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
291
-
292
- gae = scan(gates, delta)
293
-
294
- returns = gae + values
551
+ # action related
295
552
 
296
- return returns
553
+ ActionEmbeds = namedtuple('ActionEmbed', ('discrete', 'continuous'))
297
554
 
298
- # golden gate rotary - Jerry Xiong, PhD student at UIUC
299
- # https://jerryxio.ng/posts/nd-rope/
300
-
301
- def _phi(m):
302
- x = 2.
303
- for _ in range(10):
304
- x = (1. + x) ** (1. / (m + 1.))
305
- return x
306
-
307
- def make_directions(n, d):
308
- g = _phi(d)
309
- alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
310
- i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
311
- z = torch.fmod(i * alpha, 1.0)
312
- directions = torch.erfinv(2.0 * z - 1.0)
313
- directions = l2norm(directions)
314
- return directions.float()
315
-
316
- class GoldenGateRoPENd(Module):
555
+ class ActionEmbedder(Module):
317
556
  def __init__(
318
557
  self,
319
- dim_pos,
320
- heads,
321
- dim_head,
322
- rope_min_freq = 1.,
323
- rope_max_freq = 10000.,
324
- rope_p_zero_freqs = 0., # proportion of frequencies set to 0
558
+ dim,
559
+ *,
560
+ num_discrete_actions: int | tuple[int, ...] = 0,
561
+ num_continuous_actions = 0,
562
+ continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
563
+ can_unembed = False,
564
+ unembed_dim = None,
565
+ num_unembed_preds = 1,
566
+ squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
325
567
  ):
326
568
  super().__init__()
327
- assert divisible_by(dim_head, 2)
328
569
 
329
- n_freqs = dim_head // 2
330
- n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
570
+ # handle discrete actions
331
571
 
332
- omega = cat((
333
- torch.zeros(n_zero_freqs),
334
- rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
335
- ))
572
+ num_discrete_actions = tensor(ensure_tuple(num_discrete_actions))
573
+ total_discrete_actions = num_discrete_actions.sum().item()
336
574
 
337
- directions = make_directions(heads * n_freqs, dim_pos)
338
- directions = rearrange(directions, '(h f) p -> h f p', h = heads)
575
+ self.num_discrete_action_types = len(num_discrete_actions)
576
+ self.discrete_action_embed = Embedding(total_discrete_actions, dim)
339
577
 
340
- omega_expanded = rearrange(omega, 'f -> f 1')
341
- self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
578
+ self.register_buffer('num_discrete_actions', num_discrete_actions, persistent = False)
342
579
 
343
- def forward(
344
- self,
345
- pos # (b n p)
346
- ):
580
+ # continuous actions
347
581
 
348
- freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
349
- positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
582
+ self.num_continuous_action_types = num_continuous_actions
583
+ self.continuous_action_embed = Embedding(num_continuous_actions, dim)
350
584
 
351
- # thetas for freqs and positions (batch, head, seq, freq)
585
+ self.continuous_need_norm = exists(continuous_norm_stats)
352
586
 
353
- theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
587
+ if self.continuous_need_norm:
588
+ self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
354
589
 
355
- return cat((theta, theta), dim = -1)
590
+ # defaults
356
591
 
357
- class Rotary1D(Module):
358
- def __init__(
359
- self,
360
- dim_head,
361
- theta = 10000.
362
- ):
363
- super().__init__()
364
- inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
365
- self.register_buffer('inv_freq', inv_freq)
592
+ self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
593
+ self.register_buffer('default_continuous_action_types', arange(self.num_continuous_action_types), persistent = False)
366
594
 
367
- def forward(
368
- self,
369
- seq_len
370
- ):
371
- device, dtype = self.inv_freq.device, self.inv_freq.dtype
595
+ # calculate offsets
372
596
 
373
- t = torch.arange(seq_len, device = device).type(dtype)
374
- freqs = einsum(t, self.inv_freq, 'i, j -> i j')
597
+ offsets = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
598
+ self.register_buffer('discrete_action_offsets', offsets, persistent = False)
375
599
 
376
- return cat((freqs, freqs), dim = -1)
600
+ # unembedding
377
601
 
602
+ self.can_unembed = can_unembed
378
603
 
379
- def apply_rotations(
380
- rotations, # (h n d) | (n d)
381
- t # (b h n d)
382
- ):
383
- heads, dtype = t.shape[1], t.dtype
384
- t = t.float()
604
+ self.num_unembed_preds = num_unembed_preds
605
+ self.squeeze_unembed_preds = squeeze_unembed_preds
385
606
 
386
- # handle gqa for rotary
607
+ if not can_unembed:
608
+ return
387
609
 
388
- if rotations.ndim == 3 and rotations.shape[0] < heads:
389
- rotary_heads = rotations.shape[0]
610
+ unembed_dim = default(unembed_dim, dim)
611
+ self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
390
612
 
391
- assert divisible_by(heads, rotary_heads)
392
- groups = heads // rotary_heads
393
- rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
613
+ discrete_action_index = arange(total_discrete_actions)
394
614
 
395
- x1, x2 = t.chunk(2, dim = -1)
396
- rotated_half_t = cat((-x2, x1), dim = -1)
615
+ padded_num_discrete_actions = F.pad(num_discrete_actions, (1, 0), value = 0)
616
+ exclusive_cumsum = padded_num_discrete_actions.cumsum(dim = -1)
397
617
 
398
- # rotate in the positions
618
+ discrete_action_mask = (
619
+ einx.greater_equal('j, i -> i j', discrete_action_index, exclusive_cumsum[:-1]) &
620
+ einx.less('j, i -> i j', discrete_action_index, exclusive_cumsum[1:])
621
+ )
399
622
 
400
- rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
401
- return rotated.type(dtype)
623
+ self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
402
624
 
403
- # multi-head rmsnorm
625
+ self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
404
626
 
405
- class MultiHeadRMSNorm(Module):
406
- def __init__(
627
+ def embed_parameters(self):
628
+ return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
629
+
630
+ def unembed_parameters(self):
631
+ return set([self.discrete_action_unembed, self.continuous_action_unembed])
632
+
633
+ @property
634
+ def device(self):
635
+ return self.discrete_action_offsets.device
636
+
637
+ @property
638
+ def has_actions(self):
639
+ return self.num_discrete_action_types > 0 or self.num_continuous_action_types > 0
640
+
641
+ def cast_action_types(
407
642
  self,
408
- dim_head,
409
- heads = 8
643
+ action_types = None
410
644
  ):
411
- super().__init__()
412
- self.scale = dim_head ** 0.5
413
- self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
645
+ if exists(action_types) and not is_tensor(action_types):
646
+ if isinstance(action_types, int):
647
+ action_types = (action_types,)
414
648
 
415
- def forward(
649
+ action_types = tensor(action_types, device = self.device)
650
+
651
+ return action_types
652
+
653
+ def unembed(
416
654
  self,
417
- x # (b h n d)
418
- ):
419
- normed = l2norm(x)
420
- scale = (self.gamma + 1.) * self.scale
421
- return einx.multiply('... h n d, h d', normed, scale)
655
+ embeds, # (... d)
656
+ discrete_action_types = None, # (na)
657
+ continuous_action_types = None, # (na)
658
+ return_split_discrete = False,
659
+ pred_head_index: int | Tensor | None = None
422
660
 
423
- # naive attend
661
+ ): # (... discrete_na), (... continuous_na 2)
424
662
 
425
- def naive_attend(
426
- q, k, v,
427
- softclamp_value = None,
428
- scale = None,
429
- causal = False,
430
- causal_block_size = 1,
431
- mask = None
432
- ):
663
+ device = embeds.device
433
664
 
434
- if not exists(scale):
435
- scale = q.shape[-1] ** -0.5
665
+ assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
436
666
 
437
- # grouped query attention
667
+ # handle only one prediction head during inference
438
668
 
439
- groups = q.shape[1] // k.shape[1]
669
+ if exists(pred_head_index) and isinstance(pred_head_index, int):
670
+ pred_head_index = tensor(pred_head_index, device = device)
440
671
 
441
- q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
672
+ # if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
442
673
 
443
- # similarity
674
+ squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
444
675
 
445
- sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
676
+ # get action types
446
677
 
447
- # scale and attention
678
+ discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
448
679
 
449
- sim = sim * scale
680
+ # discrete actions
450
681
 
451
- # softclamping a la gemma 3
682
+ discrete_action_logits = None
452
683
 
453
- if exists(softclamp_value):
454
- sim = softclamp(sim, softclamp_value)
684
+ if self.num_discrete_action_types > 0:
455
685
 
456
- # masking
686
+ discrete_action_unembed = self.discrete_action_unembed
457
687
 
458
- mask_value = -torch.finfo(sim.dtype).max
688
+ if exists(discrete_action_types):
689
+ discrete_action_mask = self.discrete_action_mask[discrete_action_types].any(dim = 0)
459
690
 
460
- if exists(mask):
461
- sim = sim.masked_fill(~mask, mask_value)
691
+ discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
462
692
 
463
- if causal:
464
- is_blocked_causal = causal_block_size > 1
465
- i, j = sim.shape[-2:]
693
+ if exists(pred_head_index):
694
+ discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
466
695
 
467
- if is_blocked_causal:
468
- i = ceil(i / causal_block_size)
469
- j = ceil(j / causal_block_size)
696
+ discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
470
697
 
471
- causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
698
+ if self.squeeze_unembed_preds or squeeze_one_pred_head:
699
+ discrete_action_logits = safe_squeeze_first(discrete_action_logits)
472
700
 
473
- if causal_block_size > 1:
474
- causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
475
- causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
701
+ # whether to split the discrete action logits by the number of actions per action type
476
702
 
477
- sim = sim.masked_fill(causal_mask, mask_value)
703
+ if exists(discrete_action_logits) and return_split_discrete:
478
704
 
479
- # attend
705
+ split_sizes = self.num_discrete_actions[discrete_action_types] if exists(discrete_action_types) else self.num_discrete_actions
480
706
 
481
- attn = sim.softmax(dim = -1)
707
+ discrete_action_logits = discrete_action_logits.split(split_sizes.tolist(), dim = -1)
482
708
 
483
- # aggregate
709
+ # continuous actions
484
710
 
485
- out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
711
+ continuous_action_mean_log_var = None
486
712
 
487
- # merge the groups
713
+ if self.num_continuous_action_types > 0:
488
714
 
489
- return rearrange(out, 'b h g i d -> b (h g) i d')
715
+ continuous_action_unembed = self.continuous_action_unembed
490
716
 
491
- # flex attention related and factory function for attend depending on whether on cuda + flex attention available
717
+ if exists(continuous_action_types):
718
+ continuous_action_unembed = continuous_action_unembed[continuous_action_types]
492
719
 
493
- def block_mask_causal(block_size):
720
+ if exists(pred_head_index):
721
+ continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
494
722
 
495
- def inner(b, h, q, k):
496
- bq = q // block_size
497
- bk = k // block_size
498
- return bq >= bk
723
+ continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
499
724
 
500
- return inner
725
+ if self.squeeze_unembed_preds or squeeze_one_pred_head:
726
+ continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
501
727
 
502
- def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
503
- bq = q % seq_len
504
- bk = k % seq_len
728
+ return discrete_action_logits, continuous_action_mean_log_var
505
729
 
506
- is_special_start_index = seq_len - num_tokens
730
+ def sample(
731
+ self,
732
+ embed,
733
+ discrete_temperature = 1.,
734
+ continuous_temperature = 1.,
735
+ inverse_norm_continuous = None,
736
+ pred_head_index: int | Tensor | None = None,
737
+ squeeze = True,
738
+ **kwargs
739
+ ):
740
+ inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
507
741
 
508
- q_is_special = q >= is_special_start_index
509
- k_is_special = k >= is_special_start_index
742
+ discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
510
743
 
511
- if special_attend_only_itself:
512
- out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
513
- else:
514
- out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
744
+ sampled_discrete = sampled_continuous = None
515
745
 
516
- return out
746
+ if exists(discrete_logits):
747
+ sampled_discrete = []
517
748
 
518
- def block_mask_special_tokens_right(
519
- seq_len,
520
- num_tokens
521
- ):
522
- def inner(b, h, q, k):
523
- return special_token_mask(q, k, seq_len, num_tokens)
524
- return inner
749
+ for one_discrete_logits in discrete_logits:
750
+ sampled_discrete.append(gumbel_sample(one_discrete_logits, temperature = discrete_temperature, keepdim = True))
525
751
 
526
- def compose_mask(mask1, mask2):
527
- def inner(b, h, q, k):
528
- return mask1(b, h, q, k) & mask2(b, h, q, k)
752
+ sampled_discrete = cat(sampled_discrete, dim = -1)
529
753
 
530
- return inner
754
+ if exists(continuous_mean_log_var):
755
+ mean, log_var = continuous_mean_log_var.unbind(dim = -1)
756
+ std = (0.5 * log_var).exp()
531
757
 
532
- def block_mask_noop(b, h, q, k):
533
- return b >= 0
758
+ sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
534
759
 
535
- def score_mod_softclamp(value):
536
- def inner(sim, b, h, q, k):
537
- if not exists(value):
538
- return sim
760
+ # maybe inverse norm
539
761
 
540
- sim = sim / value
541
- sim = torch.tanh(sim)
542
- sim = sim * value
543
- return sim
762
+ if inverse_norm_continuous:
763
+ norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
764
+ sampled_continuous = (sampled_continuous * norm_std) + norm_mean
544
765
 
545
- return inner
766
+ return sampled_discrete, sampled_continuous
546
767
 
547
- # factory for attend function
768
+ def log_probs(
769
+ self,
770
+ embeds, # (... d)
771
+ discrete_targets = None, # (... na)
772
+ continuous_targets = None, # (... na)
773
+ discrete_action_types = None, # (na)
774
+ continuous_action_types = None, # (na)
775
+ pred_head_index: int | Tensor | None = None,
776
+ parallel_discrete_calc = None,
777
+ return_entropies = False
778
+ ):
779
+ parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
548
780
 
549
- def get_attend_fn(
550
- use_flex,
551
- seq_len,
552
- k_seq_len,
553
- causal = False,
554
- causal_block_size = 1,
555
- softclamp_value = 50.,
556
- num_special_tokens = 0, # special tokens are latents / agents
557
- block_size_per_special = None, # defaults to k_seq_len
558
- special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
559
- device = None
560
- ):
561
- block_size_per_special = default(block_size_per_special, k_seq_len)
781
+ discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
562
782
 
563
- if use_flex:
564
- # flex pathway
783
+ # discrete
565
784
 
566
- block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
785
+ discrete_log_probs = None
786
+ discrete_entropies = None
567
787
 
568
- if num_special_tokens > 0:
569
- special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
570
- block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
788
+ if exists(discrete_targets):
571
789
 
572
- block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
790
+ if parallel_discrete_calc:
791
+ # use nested tensors
573
792
 
574
- score_mod = score_mod_softclamp(softclamp_value)
575
- attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
576
- else:
577
- # naive pathway
793
+ jagged_dims = tuple(t.shape[-1] for t in discrete_action_logits)
578
794
 
579
- mask = None
580
- if num_special_tokens > 0:
581
- q_seq = torch.arange(seq_len, device = device)[:, None]
582
- k_seq = torch.arange(k_seq_len, device = device)[None, :]
795
+ discrete_action_logits = cat(discrete_action_logits, dim = -1)
583
796
 
584
- mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
797
+ discrete_action_logits, inverse_pack_lead_dims = pack_one(discrete_action_logits, '* l')
798
+ batch = discrete_action_logits.shape[0]
585
799
 
586
- attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
800
+ discrete_action_logits = rearrange(discrete_action_logits, 'b l -> (b l)')
587
801
 
588
- return attend_fn
802
+ # to nested tensor
589
803
 
590
- # attention
804
+ nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
591
805
 
592
- class Attention(Module):
593
- def __init__(
806
+ prob = nested_logits.softmax(dim = -1)
807
+
808
+ log_probs = log(prob)
809
+
810
+ # maybe entropy
811
+
812
+ if return_entropies:
813
+ discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
814
+ discrete_entropies = cat(discrete_entropies.unbind())
815
+ discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
816
+
817
+ discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')
818
+
819
+ # back to regular tensor
820
+
821
+ log_probs = cat(log_probs.unbind())
822
+ log_probs = rearrange(log_probs, '(b l) -> b l', b = batch)
823
+
824
+ log_probs = inverse_pack_lead_dims(log_probs)
825
+
826
+ # get indices to gather
827
+
828
+ discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
829
+
830
+ num_discrete_actions = self.num_discrete_actions[discrete_action_types]
831
+
832
+ offset = F.pad(num_discrete_actions.cumsum(dim = -1), (1, -1), value = 0)
833
+ log_prob_indices = discrete_targets + offset
834
+
835
+ # gather
836
+
837
+ discrete_log_probs = log_probs.gather(-1, log_prob_indices)
838
+
839
+ else:
840
+ discrete_log_probs = []
841
+ discrete_entropies = []
842
+
843
+ for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
844
+
845
+ one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
846
+ one_discrete_log_probs = log(one_discrete_probs)
847
+ one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
848
+
849
+ log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
850
+ discrete_log_probs.append(log_prob)
851
+
852
+ if return_entropies:
853
+ entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
854
+ discrete_entropies.append(entropy)
855
+
856
+ discrete_log_probs = cat(discrete_log_probs, dim = -1)
857
+
858
+ if return_entropies:
859
+ discrete_entropies = stack(discrete_entropies, dim = -1)
860
+
861
+ # continuous
862
+
863
+ continuous_log_probs = None
864
+ continuous_entropies = None
865
+
866
+ if exists(continuous_targets):
867
+ distr = mean_log_var_to_distr(continuous_action_mean_log_var)
868
+ continuous_log_probs = distr.log_prob(continuous_targets)
869
+
870
+ if return_entropies:
871
+ continuous_entropies = distr.entropy()
872
+
873
+ log_probs = (discrete_log_probs, continuous_log_probs)
874
+
875
+ if not return_entropies:
876
+ return log_probs
877
+
878
+ entropies = (discrete_entropies, continuous_entropies)
879
+
880
+ return log_probs, entropies
881
+
882
+ def kl_div(
594
883
  self,
595
- dim,
596
- dim_head = 64,
597
- query_heads = None,
598
- heads = 8,
599
- pre_rmsnorm = True,
600
- ):
601
- super().__init__()
602
- self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
884
+ src: tuple[MaybeTensor, MaybeTensor],
885
+ tgt: tuple[MaybeTensor, MaybeTensor],
886
+ reduce_across_num_actions = True
887
+ ) -> tuple[MaybeTensor, MaybeTensor]:
603
888
 
604
- # setup grouped query attention
889
+ src_discrete, src_continuous = src
890
+ tgt_discrete, tgt_continuous = tgt
605
891
 
606
- query_heads = default(query_heads, heads)
607
- assert query_heads >= heads and divisible_by(query_heads, heads)
892
+ discrete_kl_div = None
608
893
 
609
- # scaling, splitting and merging of heads
894
+ # split discrete if it is not already (multiple discrete actions)
610
895
 
611
- self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
612
- self.merge_heads = Rearrange('b h n d -> b n (h d)')
896
+ if exists(src_discrete):
613
897
 
614
- dim_q_inner = dim_head * query_heads
615
- dim_kv_inner = dim_head * heads
898
+ discrete_split = self.num_discrete_actions.tolist()
616
899
 
617
- self.to_q = LinearNoBias(dim, dim_q_inner)
618
- self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
619
- self.to_out = LinearNoBias(dim_q_inner, dim)
900
+ if is_tensor(src_discrete):
901
+ src_discrete = src_discrete.split(discrete_split, dim = -1)
620
902
 
621
- # stability related
903
+ if is_tensor(tgt_discrete):
904
+ tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
905
+
906
+ discrete_kl_divs = []
907
+
908
+ for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
909
+
910
+ src_log_probs = src_logit.log_softmax(dim = -1)
911
+ tgt_prob = tgt_logit.softmax(dim = -1)
912
+
913
+ one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
914
+
915
+ discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
916
+
917
+ discrete_kl_div = stack(discrete_kl_divs, dim = -1)
918
+
919
+ # calculate kl divergence for continuous
920
+
921
+ continuous_kl_div = None
922
+
923
+ if exists(src_continuous):
924
+ src_normal = mean_log_var_to_distr(src_continuous)
925
+ tgt_normal = mean_log_var_to_distr(tgt_continuous)
926
+
927
+ continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
928
+
929
+ # maybe reduce
930
+
931
+ if reduce_across_num_actions:
932
+ if exists(discrete_kl_div):
933
+ discrete_kl_div = discrete_kl_div.sum(dim = -1)
622
934
 
623
- self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
624
- self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
935
+ if exists(continuous_kl_div):
936
+ continuous_kl_div = continuous_kl_div.sum(dim = -1)
937
+
938
+ return discrete_kl_div, continuous_kl_div
625
939
 
626
940
  def forward(
627
941
  self,
628
- tokens, # (b n d)
629
- kv_cache = None,
630
- return_kv_cache = False,
631
- rotary_pos_emb = None,
632
- attend_fn: Callable | None = None
942
+ *,
943
+ discrete_actions = None, # (... na)
944
+ continuous_actions = None, # (... na)
945
+ discrete_action_types = None, # (na)
946
+ continuous_action_types = None, # (na)
947
+ return_sum_pooled_embeds = True
633
948
  ):
634
- tokens, inverse_packed_batch = pack_one(tokens, '* n d')
635
949
 
636
- tokens = self.norm(tokens)
950
+ discrete_embeds = continuous_embeds = None
637
951
 
638
- q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
952
+ if exists(discrete_actions):
639
953
 
640
- # split heads
954
+ discrete_action_types = default(discrete_action_types, self.default_discrete_action_types)
641
955
 
642
- q, k, v = map(self.split_heads, (q, k, v))
956
+ discrete_action_types = self.cast_action_types(discrete_action_types)
643
957
 
644
- # qk rmsnorm
958
+ offsets = self.discrete_action_offsets[discrete_action_types]
645
959
 
646
- q = self.q_heads_rmsnorm(q)
647
- k = self.k_heads_rmsnorm(k)
960
+ assert offsets.shape[-1] == discrete_actions.shape[-1], 'mismatched number of discrete actions'
648
961
 
649
- # caching
962
+ # offset the discrete actions based on the action types passed in (by default all discrete actions) and the calculated offset
650
963
 
651
- if exists(kv_cache):
652
- ck, cv = kv_cache
653
- k = cat((ck, k), dim = -2)
654
- v = cat((cv, v), dim = -2)
964
+ discrete_actions_offsetted = add('... na, na', discrete_actions, offsets)
965
+ discrete_embeds = self.discrete_action_embed(discrete_actions_offsetted)
655
966
 
656
- # rotary
967
+ if exists(continuous_actions):
968
+ continuous_action_types = default(continuous_action_types, self.default_continuous_action_types)
657
969
 
658
- if exists(rotary_pos_emb):
659
- q = apply_rotations(rotary_pos_emb, q)
660
- k = apply_rotations(rotary_pos_emb, k)
970
+ continuous_action_types = self.cast_action_types(continuous_action_types)
661
971
 
662
- # attention
972
+ assert continuous_action_types.shape[-1] == continuous_actions.shape[-1], 'mismatched number of continuous actions'
663
973
 
664
- attend_fn = default(attend_fn, naive_attend)
974
+ continuous_action_embed = self.continuous_action_embed(continuous_action_types)
665
975
 
666
- out = attend_fn(q, k, v)
976
+ # maybe normalization
667
977
 
668
- # merge heads
978
+ if self.continuous_need_norm:
979
+ norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
980
+ continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
669
981
 
670
- out = self.merge_heads(out)
982
+ # continuous embed is just the selected continuous action type with the scale
671
983
 
672
- # combine heads
984
+ continuous_embeds = multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
673
985
 
674
- out = self.to_out(out)
986
+ # return not pooled
675
987
 
676
- out = inverse_packed_batch(out)
988
+ if not return_sum_pooled_embeds:
989
+ return ActionEmbeds(discrete_embeds, continuous_embeds)
677
990
 
678
- if not return_kv_cache:
679
- return out
991
+ # handle sum pooling, which is what they did in the paper for all the actions
680
992
 
681
- return out, stack((k, v))
993
+ pooled = 0.
682
994
 
683
- # feedforward
995
+ if exists(discrete_embeds):
996
+ pooled = pooled + reduce(discrete_embeds, '... na d -> ... d', 'sum')
684
997
 
685
- class SwiGLUFeedforward(Module):
998
+ if exists(continuous_embeds):
999
+ pooled = pooled + reduce(continuous_embeds, '... na d -> ... d', 'sum')
1000
+
1001
+ return pooled
1002
+
1003
+ # generalized advantage estimate
1004
+
1005
+ @torch.no_grad()
1006
+ def calc_gae(
1007
+ rewards,
1008
+ values,
1009
+ masks = None,
1010
+ gamma = 0.99,
1011
+ lam = 0.95,
1012
+ use_accelerated = None
1013
+ ):
1014
+ assert values.shape[-1] == rewards.shape[-1]
1015
+ use_accelerated = default(use_accelerated, rewards.is_cuda)
1016
+
1017
+ if not exists(masks):
1018
+ masks = torch.ones_like(values)
1019
+
1020
+ values = F.pad(values, (0, 1), value = 0.)
1021
+ values, values_next = values[..., :-1], values[..., 1:]
1022
+
1023
+ delta = rewards + gamma * values_next * masks - values
1024
+ gates = gamma * lam * masks
1025
+
1026
+ scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
1027
+
1028
+ gae = scan(gates, delta)
1029
+
1030
+ returns = gae + values
1031
+
1032
+ return returns
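For reference, the recursion computed here is the usual generalized advantage estimate, evaluated right-to-left with the associative scan: delta_t = r_t + gamma * V_{t+1} * m_t - V_t and A_t = delta_t + gamma * lam * m_t * A_{t+1}, with the lambda-returns A_t + V_t handed back. A shape-only sketch (values made up, import path assumed):

import torch
from dreamer4.dreamer4 import calc_gae  # assumed import path

rewards = torch.randn(4, 16)   # (batch, time)
values  = torch.randn(4, 16)
masks   = torch.ones(4, 16)    # zero out entries where an episode terminated

returns = calc_gae(rewards, values, masks, gamma = 0.99, lam = 0.95)   # (4, 16) lambda-returns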
1033
+
1034
+ # rotary embeddings for time
1035
+
1036
+ class Rotary1D(Module):
686
1037
  def __init__(
687
1038
  self,
688
- dim,
689
- expansion_factor = 4,
690
- pre_rmsnorm = True
1039
+ dim_head,
1040
+ theta = 10000.
691
1041
  ):
692
1042
  super().__init__()
693
- self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
1043
+ inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
1044
+ self.register_buffer('inv_freq', inv_freq)
694
1045
 
695
- dim_inner = int(dim * expansion_factor * 2 / 3)
1046
+ def forward(
1047
+ self,
1048
+ seq_len,
1049
+ offset = 0
1050
+ ):
1051
+ device, dtype = self.inv_freq.device, self.inv_freq.dtype
696
1052
 
697
- self.proj_in = Linear(dim, dim_inner * 2)
698
- self.proj_out = Linear(dim_inner, dim)
1053
+ t = torch.arange(seq_len, device = device).type(dtype) + offset
1054
+ freqs = einsum(t, self.inv_freq, 'i, j -> i j')
699
1055
 
700
- def forward(self, x):
701
- x = self.norm(x)
1056
+ return cat((freqs, freqs), dim = -1)
702
1057
 
703
- x, gates = self.proj_in(x).chunk(2, dim = -1)
704
- x = x * F.gelu(gates)
705
1058
 
706
- return self.proj_out(x)
1059
+ def apply_rotations(
1060
+ rotations, # (h n d) | (n d)
1061
+ t # (b h n d)
1062
+ ):
707
1063
 
708
- # video tokenizer
1064
+ heads, seq_len, dtype = *t.shape[1:3], t.dtype
709
1065
 
710
- class VideoTokenizer(Module):
1066
+ rotations_seq_len = rotations.shape[-2]
1067
+
1068
+ # handle kv caching with rotations
1069
+
1070
+ if rotations_seq_len > seq_len:
1071
+ rotations = rotations[-seq_len:]
1072
+
1073
+ # precision
1074
+
1075
+ t = t.float()
1076
+
1077
+ # handle gqa for rotary
1078
+
1079
+ if rotations.ndim == 3 and rotations.shape[0] < heads:
1080
+ rotary_heads = rotations.shape[0]
1081
+
1082
+ assert divisible_by(heads, rotary_heads)
1083
+ groups = heads // rotary_heads
1084
+ rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
1085
+
1086
+ x1, x2 = t.chunk(2, dim = -1)
1087
+ rotated_half_t = cat((-x2, x1), dim = -1)
1088
+
1089
+ # rotate in the positions
1090
+
1091
+ rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
1092
+ return rotated.type(dtype)
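A quick shape sketch (sizes are arbitrary, import path assumed): Rotary1D produces one angle per timestep and frequency, duplicated across the head dimension, and apply_rotations rotates the query and key heads with them; on the kv-cache path only the trailing positions of the rotations are used:

import torch
from dreamer4.dreamer4 import Rotary1D, apply_rotations  # assumed import path

rotary = Rotary1D(dim_head = 64)

q = torch.randn(2, 8, 16, 64)   # (batch, heads, time, dim_head)
k = torch.randn(2, 8, 16, 64)

rotations = rotary(16)          # (16, 64) angles, frequencies repeated across the last dim

q = apply_rotations(rotations, q)
k = apply_rotations(rotations, k)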
1093
+
1094
+ # multi-head rmsnorm
1095
+
1096
+ class MultiHeadRMSNorm(Module):
711
1097
  def __init__(
712
1098
  self,
713
- dim,
714
- dim_latent,
715
- patch_size,
716
- image_height = None,
717
- image_width = None,
718
- num_latent_tokens = 4,
719
- encoder_depth = 4,
720
- decoder_depth = 4,
721
- attn_kwargs: dict = dict(),
722
- attn_dim_head = 64,
723
- attn_heads = 8,
724
- attn_softclamp_value = 50.,
725
- ff_kwargs: dict = dict(),
726
- decoder_pos_mlp_depth = 2,
727
- channels = 3,
728
- per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
729
- lpips_loss_network: Module | None = None,
730
- lpips_loss_weight = 0.2,
731
- nd_rotary_kwargs: dict = dict(
732
- rope_min_freq = 1.,
733
- rope_max_freq = 10000.,
734
- rope_p_zero_freqs = 0.
1099
+ dim_head,
1100
+ heads = 8
1101
+ ):
1102
+ super().__init__()
1103
+ self.scale = dim_head ** 0.5
1104
+ self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
1105
+
1106
+ def forward(
1107
+ self,
1108
+ x # (b h n d)
1109
+ ):
1110
+ normed = l2norm(x)
1111
+ scale = (self.gamma + 1.) * self.scale
1112
+ return multiply('... h n d, h d', normed, scale)
1113
+
1114
+ # naive attend
1115
+
1116
+ def naive_attend(
1117
+ q, k, v,
1118
+ softclamp_value = None,
1119
+ scale = None,
1120
+ causal = False,
1121
+ causal_block_size = 1,
1122
+ mask = None
1123
+ ):
1124
+
1125
+ if not exists(scale):
1126
+ scale = q.shape[-1] ** -0.5
1127
+
1128
+ # grouped query attention
1129
+
1130
+ groups = q.shape[1] // k.shape[1]
1131
+
1132
+ q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
1133
+
1134
+ # similarity
1135
+
1136
+ sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
1137
+
1138
+ # scale and attention
1139
+
1140
+ sim = sim * scale
1141
+
1142
+ # softclamping a la gemma 3
1143
+
1144
+ if exists(softclamp_value):
1145
+ sim = softclamp(sim, softclamp_value)
1146
+
1147
+ # masking
1148
+
1149
+ mask_value = -torch.finfo(sim.dtype).max
1150
+
1151
+ if exists(mask):
1152
+ sim = sim.masked_fill(~mask, mask_value)
1153
+
1154
+ if causal:
1155
+ is_blocked_causal = causal_block_size > 1
1156
+ i, j = sim.shape[-2:]
1157
+
1158
+ if is_blocked_causal:
1159
+ i = ceil(i / causal_block_size)
1160
+ j = ceil(j / causal_block_size)
1161
+
1162
+ causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
1163
+
1164
+ if causal_block_size > 1:
1165
+ causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
1166
+ causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
1167
+
1168
+ sim = sim.masked_fill(causal_mask, mask_value)
1169
+
1170
+ # attend
1171
+
1172
+ attn = sim.softmax(dim = -1)
1173
+
1174
+ # aggregate
1175
+
1176
+ out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
1177
+
1178
+ # merge the groups
1179
+
1180
+ return rearrange(out, 'b h g i d -> b (h g) i d')
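A small shape check for the grouped-query path (made-up sizes, import path assumed): 8 query heads share 2 key/value heads, so the queries are folded into groups of 4 before the similarities are computed and merged back afterwards:

import torch
from dreamer4.dreamer4 import naive_attend  # assumed import path

q = torch.randn(1, 8, 32, 64)   # (batch, query heads, seq, dim_head)
k = torch.randn(1, 2, 32, 64)   # (batch, kv heads, seq, dim_head)
v = torch.randn(1, 2, 32, 64)

out = naive_attend(q, k, v, causal = True, softclamp_value = 50.)
assert out.shape == (1, 8, 32, 64)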
1181
+
1182
+ # flex attention related and factory function for attend depending on whether on cuda + flex attention available
1183
+
1184
+ def block_mask_causal(block_size):
1185
+
1186
+ def inner(b, h, q, k):
1187
+ bq = q // block_size
1188
+ bk = k // block_size
1189
+ return bq >= bk
1190
+
1191
+ return inner
1192
+
1193
+ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
1194
+ bq = q % seq_len
1195
+ bk = k % seq_len
1196
+
1197
+ is_special_start_index = seq_len - num_tokens
1198
+
1199
+ q_is_special = q >= is_special_start_index
1200
+ k_is_special = k >= is_special_start_index
1201
+
1202
+ if special_attend_only_itself:
1203
+ out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
1204
+ else:
1205
+ out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
1206
+
1207
+ return out
1208
+
1209
+ def block_mask_special_tokens_right(
1210
+ seq_len,
1211
+ num_tokens,
1212
+ special_attend_only_itself = False
1213
+ ):
1214
+ def inner(b, h, q, k):
1215
+ return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
1216
+ return inner
1217
+
1218
+ def compose_mask(mask1, mask2):
1219
+ def inner(b, h, q, k):
1220
+ return mask1(b, h, q, k) & mask2(b, h, q, k)
1221
+
1222
+ return inner
1223
+
1224
+ def block_mask_noop(b, h, q, k):
1225
+ return b >= 0
1226
+
1227
+ def score_mod_softclamp(value):
1228
+ def inner(sim, b, h, q, k):
1229
+ if not exists(value):
1230
+ return sim
1231
+
1232
+ sim = sim / value
1233
+ sim = torch.tanh(sim)
1234
+ sim = sim * value
1235
+ return sim
1236
+
1237
+ return inner
1238
+
1239
+ # factory for attend function
1240
+
1241
+ def get_attend_fn(
1242
+ use_flex,
1243
+ seq_len,
1244
+ k_seq_len,
1245
+ causal = False,
1246
+ causal_block_size = 1,
1247
+ softclamp_value = 50.,
1248
+ num_special_tokens = 0, # special tokens are latents / agents
1249
+ block_size_per_special = None, # defaults to k_seq_len
1250
+ special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
1251
+ device = None
1252
+ ):
1253
+ block_size_per_special = default(block_size_per_special, k_seq_len)
1254
+
1255
+ if use_flex:
1256
+ # flex pathway
1257
+
1258
+ block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
1259
+
1260
+ if num_special_tokens > 0:
1261
+ special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
1262
+ block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
1263
+
1264
+ block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
1265
+
1266
+ score_mod = score_mod_softclamp(softclamp_value)
1267
+ attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
1268
+ else:
1269
+ # naive pathway
1270
+
1271
+ mask = None
1272
+ if num_special_tokens > 0:
1273
+ q_seq = torch.arange(seq_len, device = device)[:, None]
1274
+ k_seq = torch.arange(k_seq_len, device = device)[None, :]
1275
+
1276
+ mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
1277
+
1278
+ attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
1279
+
1280
+ return attend_fn
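For illustration (sizes and the import path are assumptions), the factory bakes the masking into the returned callable so callers only pass q, k, v; with use_flex = True on CUDA it builds a flex-attention block mask, otherwise it falls back to the naive dense path shown above:

import torch
from dreamer4.dreamer4 import get_attend_fn  # assumed import path

attend = get_attend_fn(
    use_flex = False,          # naive fallback; True requires CUDA + flex attention
    seq_len = 65,
    k_seq_len = 65,
    causal = False,
    num_special_tokens = 1,    # e.g. one latent / agent token on the right of the spatial block
    device = torch.device('cpu')
)

q = torch.randn(1, 8, 65, 64)
k = torch.randn(1, 8, 65, 64)
v = torch.randn(1, 8, 65, 64)

out = attend(q, k, v)          # (1, 8, 65, 64), with the modality-to-special masking applied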
1281
+
1282
+ # attention
1283
+
1284
+ class Attention(Module):
1285
+ def __init__(
1286
+ self,
1287
+ dim,
1288
+ dim_head = 64,
1289
+ query_heads = None,
1290
+ heads = 8,
1291
+ pre_rmsnorm = True,
1292
+ gate_values = True,
1293
+ rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
1294
+ rmsnorm_key = True,
1295
+ value_residual = True
1296
+ ):
1297
+ super().__init__()
1298
+ self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
1299
+
1300
+ # setup grouped query attention
1301
+
1302
+ query_heads = default(query_heads, heads)
1303
+ assert query_heads >= heads and divisible_by(query_heads, heads)
1304
+
1305
+ # scaling, splitting and merging of heads
1306
+
1307
+ self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
1308
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')
1309
+
1310
+ dim_q_inner = dim_head * query_heads
1311
+ dim_kv_inner = dim_head * heads
1312
+
1313
+ self.to_q = LinearNoBias(dim, dim_q_inner)
1314
+ self.to_k = LinearNoBias(dim, dim_kv_inner)
1315
+ self.to_v = LinearNoBias(dim, dim_kv_inner)
1316
+ self.to_out = LinearNoBias(dim_q_inner, dim)
1317
+
1318
+ # alphafold gating per head, for attending to nothing
1319
+
1320
+ self.to_gates = None
1321
+
1322
+ if gate_values:
1323
+ self.to_gates = Sequential(
1324
+ LinearNoBias(dim, query_heads),
1325
+ Rearrange('b n h -> b h n 1'),
1326
+ nn.Sigmoid()
1327
+ )
1328
+
1329
+ # stability related
1330
+
1331
+ self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
1332
+ self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
1333
+
1334
+ # value residual
1335
+
1336
+ self.to_learned_value_residual_mix = nn.Sequential(
1337
+ nn.Linear(dim, heads),
1338
+ Rearrange('b n h -> b h n 1'),
1339
+ nn.Sigmoid()
1340
+ ) if value_residual else None
1341
+
1342
+ def muon_parameters(self):
1343
+ # omit the queries and keys for now given what we learned from kimi 2 paper
1344
+
1345
+ return [
1346
+ *self.to_v.parameters(),
1347
+ *self.to_out.parameters(),
1348
+ ]
1349
+
1350
+ def forward(
1351
+ self,
1352
+ tokens, # (b n d)
1353
+ kv_cache = None,
1354
+ return_intermediates = False,
1355
+ rotary_pos_emb = None,
1356
+ residual_values = None, # (b n h d)
1357
+ attend_fn: Callable | None = None
1358
+ ):
1359
+ tokens, inverse_packed_batch = pack_one(tokens, '* n d')
1360
+
1361
+ tokens = self.norm(tokens)
1362
+
1363
+ q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
1364
+
1365
+ # split heads
1366
+
1367
+ q, k, v = map(self.split_heads, (q, k, v))
1368
+
1369
+ # handle maybe value residual
1370
+
1371
+ if exists(residual_values):
1372
+ residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
1373
+
1374
+ assert exists(self.to_learned_value_residual_mix)
1375
+
1376
+ learned_mix = self.to_learned_value_residual_mix(tokens)
1377
+
1378
+ v = v.lerp(residual_values, learned_mix)
1379
+
1380
+ # qk rmsnorm
1381
+
1382
+ q = self.q_heads_rmsnorm(q)
1383
+ k = self.k_heads_rmsnorm(k)
1384
+
1385
+ # rotary
1386
+
1387
+ if exists(rotary_pos_emb):
1388
+ q = apply_rotations(rotary_pos_emb, q)
1389
+ k = apply_rotations(rotary_pos_emb, k)
1390
+
1391
+ # caching
1392
+
1393
+ if exists(kv_cache):
1394
+ ck, cv = kv_cache
1395
+ k = cat((ck, k), dim = -2)
1396
+ v = cat((cv, v), dim = -2)
1397
+
1398
+ # attention
1399
+
1400
+ attend_fn = default(attend_fn, naive_attend)
1401
+
1402
+ out = attend_fn(q, k, v)
1403
+
1404
+ # gate values
1405
+
1406
+ if exists(self.to_gates):
1407
+ gates = self.to_gates(tokens)
1408
+ out = out * gates
1409
+
1410
+ # merge heads
1411
+
1412
+ out = self.merge_heads(out)
1413
+
1414
+ # combine heads
1415
+
1416
+ out = self.to_out(out)
1417
+
1418
+ out = inverse_packed_batch(out)
1419
+
1420
+ if not return_intermediates:
1421
+ return out
1422
+
1423
+ return out, AttentionIntermediates(stack((k, v)), tokens)
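A short sketch of the attention block with kv caching (dimensions are made up, import path assumed): return_intermediates = True hands back the stacked key/value cache plus the pre-attention normed inputs, and the cache can be fed back in when decoding one timestep at a time:

import torch
from dreamer4.dreamer4 import Attention  # assumed import path

attn = Attention(dim = 512, heads = 8, dim_head = 64)

tokens = torch.randn(2, 16, 512)
out, (kv_cache, normed_inputs) = attn(tokens, return_intermediates = True)

# decode a single new token against the cached keys / values
next_out = attn(torch.randn(2, 1, 512), kv_cache = kv_cache)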
1424
+
1425
+ # feedforward
1426
+
1427
+ class SwiGLUFeedforward(Module):
1428
+ def __init__(
1429
+ self,
1430
+ dim,
1431
+ expansion_factor = 4,
1432
+ pre_rmsnorm = True
1433
+ ):
1434
+ super().__init__()
1435
+ self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
1436
+
1437
+ dim_inner = int(dim * expansion_factor * 2 / 3)
1438
+
1439
+ self.proj_in = Linear(dim, dim_inner * 2)
1440
+ self.proj_out = Linear(dim_inner, dim)
1441
+
1442
+ def muon_parameters(self):
1443
+ return [
1444
+ self.proj_in.weight,
1445
+ self.proj_out.weight,
1446
+ ]
1447
+
1448
+ def forward(self, x):
1449
+ x = self.norm(x)
1450
+
1451
+ x, gates = self.proj_in(x).chunk(2, dim = -1)
1452
+ x = x * F.gelu(gates)
1453
+
1454
+ return self.proj_out(x)
1455
+
1456
+ # axial space time transformer
1457
+
1458
+ class AxialSpaceTimeTransformer(Module):
1459
+ def __init__(
1460
+ self,
1461
+ dim,
1462
+ depth,
1463
+ attn_heads = 8,
1464
+ attn_dim_head = 64,
1465
+ attn_softclamp_value = 50.,
1466
+ time_block_every = 4,
1467
+ attn_kwargs: dict = dict(),
1468
+ ff_kwargs: dict = dict(),
1469
+ num_residual_streams = 1,
1470
+ num_special_spatial_tokens = 1,
1471
+ special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
1472
+ final_norm = True,
1473
+ value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
1474
+ rnn_time = False
1475
+ ):
1476
+ super().__init__()
1477
+ assert depth >= time_block_every, f'depth must be at least {time_block_every}'
1478
+
1479
+ # hyper connections
1480
+
1481
+ hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
1482
+
1483
+ # attention
1484
+
1485
+ self.attn_softclamp_value = attn_softclamp_value
1486
+
1487
+ # attention masking
1488
+
1489
+ self.special_attend_only_itself = special_attend_only_itself
1490
+
1491
+ # time rotary embedding
1492
+
1493
+ self.time_rotary = Rotary1D(attn_dim_head)
1494
+
1495
+ # project initial for value residuals
1496
+
1497
+ self.value_residual = value_residual
1498
+
1499
+ if value_residual:
1500
+ dim_inner = attn_dim_head * attn_heads
1501
+
1502
+ self.to_value_residual = nn.Sequential(
1503
+ nn.RMSNorm(dim),
1504
+ nn.Linear(dim, dim_inner, bias = False),
1505
+ Rearrange('... (h d) -> ... h d', h = attn_heads)
1506
+ )
1507
+
1508
+ # a gru layer across time
1509
+
1510
+ self.rnn_time = rnn_time
1511
+ rnn_layers = []
1512
+
1513
+ # transformer
1514
+
1515
+ layers = []
1516
+ is_time = []
1517
+
1518
+ for i in range(depth):
1519
+ layer_index = i + 1
1520
+
1521
+ is_time_block = divisible_by(layer_index, time_block_every)
1522
+ is_time.append(is_time_block)
1523
+
1524
+ rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
1525
+ rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
1526
+
1527
+ layers.append(ModuleList([
1528
+ rearrange_to_attend,
1529
+ rearrange_from_attend,
1530
+ hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
1531
+ hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
1532
+ ]))
1533
+
1534
+ rnn_layers.append(ModuleList([
1535
+ nn.RMSNorm(dim),
1536
+ nn.GRU(dim, dim, batch_first = True)
1537
+ ]) if is_time_block and rnn_time else None)
1538
+
1539
+ self.layers = ModuleList(layers)
1540
+ self.rnn_layers = ModuleList(rnn_layers)
1541
+
1542
+ self.is_time = is_time
1543
+
1544
+ # final norm
1545
+
1546
+ self.final_norm = nn.RMSNorm(dim) if final_norm else nn.Identity()
1547
+
1548
+ # special tokens
1549
+
1550
+ self.num_special_spatial_tokens = num_special_spatial_tokens
1551
+
1552
+ def muon_parameters(self):
1553
+ muon_params = []
1554
+
1555
+ for m in self.modules():
1556
+ if isinstance(m, (Attention, SwiGLUFeedforward)):
1557
+ muon_params.extend(m.muon_parameters())
1558
+
1559
+ return muon_params
1560
+
1561
+ def forward(
1562
+ self,
1563
+ tokens, # (b t s d)
1564
+ kv_cache: Tensor | None = None, # (y 2 b h t d)
1565
+ return_intermediates = False
1566
+
1567
+ ): # (b t s d) | (y 2 b h t d)
1568
+
1569
+ batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
1570
+
1571
+ assert tokens.ndim == 4
1572
+
1573
+ # attend functions for space and time
1574
+
1575
+ has_kv_cache = exists(kv_cache)
1576
+ use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # the kv cache shape currently breaks flex attention - todo: fix
1577
+
1578
+ attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
1579
+
1580
+ space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_special_spatial_tokens, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - the modality tokens cannot attend to it
1581
+
1582
+ time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
1583
+
1584
+ # prepare cache
1585
+
1586
+ time_attn_kv_caches = []
1587
+
1588
+ if has_kv_cache:
1589
+ past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
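+ # with a kv cache, only the newest timestep is run through the layers - earlier timesteps are represented by their cached keys / values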
1590
+
1591
+ rotary_seq_len = 1
1592
+ rotary_pos_offset = past_tokens.shape[1]
1593
+ else:
1594
+ rotary_seq_len = time
1595
+ rotary_pos_offset = 0
1596
+
1597
+ kv_cache = default(kv_cache, (None,))
1598
+
1599
+ iter_kv_cache = iter(kv_cache)
1600
+
1601
+ # rotary
1602
+
1603
+ rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
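+ # the offset keeps the single new timestep's rotary position consistent with the cached past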
1604
+
1605
+ # value residual
1606
+
1607
+ residual_values = None
1608
+
1609
+ if self.value_residual:
1610
+ residual_values = self.to_value_residual(tokens)
1611
+
1612
+ # normed attention inputs
1613
+
1614
+ normed_time_attn_inputs = []
1615
+ normed_space_attn_inputs = []
1616
+
1617
+ # attention
1618
+
1619
+ tokens = self.expand_streams(tokens)
1620
+
1621
+ for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
1622
+
1623
+ tokens = pre_attn_rearrange(tokens)
1624
+
1625
+ # maybe rnn for time
1626
+
1627
+ if layer_is_time and exists(maybe_rnn_modules):
1628
+ rnn_prenorm, rnn = maybe_rnn_modules
1629
+
1630
+ rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
1631
+
1632
+ rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
1633
+
1634
+ rnn_out = inverse_pack_time(rnn_out)
1635
+
1636
+ tokens = rnn_out + tokens
1637
+
1638
+ # when it is an axial time attention block, attention should be causal
1639
+
1640
+ attend_fn = time_attend if layer_is_time else space_attend
1641
+
1642
+ layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
1643
+
1644
+ # maybe past kv cache
1645
+
1646
+ maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
1647
+
1648
+ # residual values
1649
+
1650
+ layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
1651
+
1652
+ # attention layer
1653
+
1654
+ tokens, attn_intermediates = attn(
1655
+ tokens,
1656
+ rotary_pos_emb = layer_rotary_pos_emb,
1657
+ attend_fn = attend_fn,
1658
+ kv_cache = maybe_kv_cache,
1659
+ residual_values = layer_residual_values,
1660
+ return_intermediates = True
1661
+ )
1662
+
1663
+ tokens = post_attn_rearrange(tokens)
1664
+
1665
+ # feedforward layer
1666
+
1667
+ tokens = ff(tokens)
1668
+
1669
+ # save kv cache if is time layer
1670
+
1671
+ if layer_is_time:
1672
+ time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
1673
+
1674
+ # save time attention inputs for decorr
1675
+
1676
+ space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
1677
+
1678
+ space_or_time_inputs.append(attn_intermediates.normed_inputs)
1679
+
1680
+ tokens = self.reduce_streams(tokens)
1681
+
1682
+ out = self.final_norm(tokens)
1683
+
1684
+ if has_kv_cache:
1685
+ # just concat the past tokens back on for now, todo - clean up the logic
1686
+ out = cat((past_tokens, out), dim = 1)
1687
+
1688
+ if not return_intermediates:
1689
+ return out
1690
+
1691
+ intermediates = TransformerIntermediates(
1692
+ stack(time_attn_kv_caches),
1693
+ safe_stack(normed_time_attn_inputs),
1694
+ safe_stack(normed_space_attn_inputs)
1695
+ )
1696
+
1697
+ return out, intermediates
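+ # usage sketch (shapes and hyperparameters assumed for illustration, not part of the package):
+ # tokens are (batch, time, space, dim) - spatial layers attend bidirectionally within a frame,
+ # and every `time_block_every`-th layer attends causally across time
+ #
+ #   transformer = AxialSpaceTimeTransformer(dim = 512, depth = 8)
+ #   tokens = torch.randn(2, 16, 64, 512)
+ #   out, intermediates = transformer(tokens, return_intermediates = True)
+ #   assert out.shape == tokens.shape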
1698
+
1699
+ # video tokenizer
1700
+
1701
+ class VideoTokenizer(Module):
1702
+ def __init__(
1703
+ self,
1704
+ dim,
1705
+ dim_latent,
1706
+ patch_size,
1707
+ image_height = None,
1708
+ image_width = None,
1709
+ num_latent_tokens = 4,
1710
+ encoder_depth = 4,
1711
+ decoder_depth = 4,
1712
+ time_block_every = 4,
1713
+ attn_kwargs: dict = dict(),
1714
+ attn_dim_head = 64,
1715
+ attn_heads = 8,
1716
+ attn_softclamp_value = 50.,
1717
+ ff_kwargs: dict = dict(),
1718
+ decoder_pos_mlp_depth = 2,
1719
+ channels = 3,
1720
+ per_image_patch_mask_prob = (0., 0.9), # the patch masking probability appears to be drawn per image, uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
1721
+ lpips_loss_network: Module | None = None,
1722
+ lpips_loss_weight = 0.2,
1723
+ encoder_add_decor_aux_loss = False,
1724
+ decor_aux_loss_weight = 0.1,
1725
+ decorr_sample_frac = 0.25,
1726
+ nd_rotary_kwargs: dict = dict(
1727
+ rope_min_freq = 1.,
1728
+ rope_max_freq = 10000.,
1729
+ rope_p_zero_freqs = 0.
1730
+ ),
1731
+ num_residual_streams = 1,
1732
+ ):
1733
+ super().__init__()
1734
+
1735
+ self.patch_size = patch_size
1736
+
1737
+ # special tokens
1738
+
1739
+ assert num_latent_tokens >= 1
1740
+ self.num_latent_tokens = num_latent_tokens
1741
+ self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
1742
+
1743
+ # hyper connections
1744
+
1745
+ hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim)
1746
+
1747
+ # mae masking - Kaiming He paper from long ago
1748
+
1749
+ self.per_image_patch_mask_prob = per_image_patch_mask_prob
1750
+ self.mask_token = Parameter(randn(dim) * 1e-2)
1751
+
1752
+ # patch and unpatch
1753
+
1754
+ dim_patch = channels * patch_size ** 2
1755
+
1756
+ self.patch_to_tokens = Sequential(
1757
+ Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
1758
+ Linear(dim_patch, dim)
1759
+ )
1760
+
1761
+ self.tokens_to_patch = Sequential(
1762
+ Linear(dim, dim_patch),
1763
+ Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
1764
+ )
1765
+
1766
+ # encoder space / time transformer
1767
+
1768
+ self.encoder_transformer = AxialSpaceTimeTransformer(
1769
+ dim = dim,
1770
+ depth = encoder_depth,
1771
+ attn_dim_head = attn_dim_head,
1772
+ attn_softclamp_value = attn_softclamp_value,
1773
+ time_block_every = time_block_every,
1774
+ num_special_spatial_tokens = num_latent_tokens,
1775
+ num_residual_streams = num_residual_streams,
1776
+ final_norm = True
1777
+ )
1778
+
1779
+ # latents
1780
+
1781
+ self.encoded_to_latents = Sequential(
1782
+ LinearNoBias(dim, dim_latent),
1783
+ nn.Tanh(),
1784
+ )
1785
+
1786
+ self.latents_to_decoder = LinearNoBias(dim_latent, dim)
1787
+
1788
+ # decoder
1789
+
1790
+ self.image_height = image_height
1791
+ self.image_width = image_width
1792
+
1793
+ # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
1794
+
1795
+ self.to_decoder_pos_emb = create_mlp(
1796
+ dim_in = 2,
1797
+ dim = dim * 2,
1798
+ dim_out = dim,
1799
+ depth = decoder_pos_mlp_depth,
1800
+ )
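+ # the decoder positions are continuous (height, width) coordinates in [-1., 1.] fed through this mlp, so the decoder is not tied to a fixed resolution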
1801
+
1802
+ # decoder transformer
1803
+
1804
+ self.decoder_transformer = AxialSpaceTimeTransformer(
1805
+ dim = dim,
1806
+ depth = decoder_depth,
1807
+ attn_dim_head = attn_dim_head,
1808
+ attn_softclamp_value = attn_softclamp_value,
1809
+ time_block_every = time_block_every,
1810
+ num_special_spatial_tokens = num_latent_tokens,
1811
+ num_residual_streams = num_residual_streams,
1812
+ special_attend_only_itself = True,
1813
+ final_norm = True
1814
+ )
1815
+
1816
+ # loss related
1817
+
1818
+ self.register_buffer('zero', tensor(0.), persistent = False)
1819
+
1820
+ self.has_lpips_loss = lpips_loss_weight > 0.
1821
+ self.lpips_loss_weight = lpips_loss_weight
1822
+
1823
+ if self.has_lpips_loss:
1824
+ self.lpips = LPIPSLoss(lpips_loss_network)
1825
+
1826
+ # decorr aux loss
1827
+ # https://arxiv.org/abs/2510.14657
1828
+
1829
+ self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
1830
+ self.decorr_aux_loss_weight = decor_aux_loss_weight
1831
+
1832
+ self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
1833
+
1834
+ @property
1835
+ def device(self):
1836
+ return self.zero.device
1837
+
1838
+ def muon_parameters(self):
1839
+ return [
1840
+ *self.encoder_transformer.muon_parameters(),
1841
+ *self.decoder_transformer.muon_parameters()
1842
+ ]
1843
+
1844
+ @torch.no_grad()
1845
+ def tokenize(
1846
+ self,
1847
+ video
1848
+ ):
1849
+ self.eval()
1850
+ return self.forward(video, return_latents = True)
1851
+
1852
+ def decode(
1853
+ self,
1854
+ latents, # (b t n d)
1855
+ height = None,
1856
+ width = None,
1857
+ ): # (b c t h w)
1858
+
1859
+ height = default(height, self.image_height)
1860
+ width = default(width, self.image_width)
1861
+
1862
+ assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
1863
+
1864
+ batch, time, device = *latents.shape[:2], latents.device
1865
+
1866
+ use_flex = latents.is_cuda and exists(flex_attention)
1867
+
1868
+ num_patch_height = height // self.patch_size
1869
+ num_patch_width = width // self.patch_size
1870
+
1871
+ # latents to tokens
1872
+
1873
+ latent_tokens = self.latents_to_decoder(latents)
1874
+
1875
+ # generate decoder positional embedding and concat the latent token
1876
+
1877
+ spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
1878
+ spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
1879
+
1880
+ space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
1881
+
1882
+ decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
1883
+ decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
1884
+
1885
+ tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
1886
+
1887
+ # decoder attention
1888
+
1889
+ tokens = self.decoder_transformer(tokens)
1890
+
1891
+ # unpack latents
1892
+
1893
+ tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
1894
+
1895
+ # project back to patches
1896
+
1897
+ recon_video = self.tokens_to_patch(tokens)
1898
+
1899
+ return recon_video
1900
+
1901
+ def forward(
1902
+ self,
1903
+ video, # (b c t h w)
1904
+ return_latents = False,
1905
+ mask_patches = None,
1906
+ return_all_losses = False
1907
+ ):
1908
+ batch, _, time, height, width = video.shape
1909
+ patch_size, device = self.patch_size, video.device
1910
+
1911
+ assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
1912
+
1913
+ # to tokens
1914
+
1915
+ tokens = self.patch_to_tokens(video)
1916
+
1917
+ # get some dimensions
1918
+
1919
+ num_patch_height, num_patch_width, _ = tokens.shape[-3:]
1920
+
1921
+ # masking
1922
+
1923
+ mask_patches = default(mask_patches, self.training)
1924
+
1925
+ if mask_patches:
1926
+ min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
1927
+
1928
+ mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t)
1929
+
1930
+ mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
1931
+ mask_patch = torch.bernoulli(mask_prob) == 1.
1932
+
1933
+ tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
1934
+
1935
+ # pack space
1936
+
1937
+ tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
1938
+
1939
+ # add the latent
1940
+
1941
+ latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
1942
+
1943
+ tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
1944
+
1945
+ # encoder attention
1946
+
1947
+ tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
1948
+
1949
+ # latent bottleneck
1950
+
1951
+ tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
1952
+
1953
+ latents = self.encoded_to_latents(latents)
1954
+
1955
+ if return_latents:
1956
+ return latents
1957
+
1958
+ recon_video = self.decode(latents, height = height, width = width)
1959
+
1960
+ # losses
1961
+
1962
+ recon_loss = F.mse_loss(video, recon_video)
1963
+
1964
+ lpips_loss = self.zero
1965
+
1966
+ if self.has_lpips_loss:
1967
+ lpips_loss = self.lpips(video, recon_video)
1968
+
1969
+ time_decorr_loss = space_decorr_loss = self.zero
1970
+
1971
+ if self.encoder_add_decor_aux_loss:
1972
+ if exists(time_attn_normed_inputs):
1973
+ time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
1974
+
1975
+ if exists(space_attn_normed_inputs):
1976
+ space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
1977
+
1978
+ # losses
1979
+
1980
+ total_loss = (
1981
+ recon_loss +
1982
+ lpips_loss * self.lpips_loss_weight +
1983
+ time_decorr_loss * self.decorr_aux_loss_weight +
1984
+ space_decorr_loss * self.decorr_aux_loss_weight
1985
+ )
1986
+
1987
+ if not return_all_losses:
1988
+ return total_loss
1989
+
1990
+ losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
1991
+
1992
+ return total_loss, TokenizerLosses(*losses)
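+ # usage sketch (shapes and hyperparameters assumed for illustration, not part of the package):
+ #
+ #   tokenizer = VideoTokenizer(dim = 512, dim_latent = 32, patch_size = 16)
+ #   video = torch.randn(1, 3, 4, 256, 256)    # (batch, channels, time, height, width)
+ #   loss = tokenizer(video)                    # masked reconstruction (+ lpips) training loss
+ #   loss.backward()
+ #   latents = tokenizer.tokenize(video)        # (batch, time, num_latent_tokens, dim_latent)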
1993
+
1994
+ # dynamics model, axial space-time transformer
1995
+
1996
+ class DynamicsWorldModel(Module):
1997
+ def __init__(
1998
+ self,
1999
+ dim,
2000
+ dim_latent,
2001
+ video_tokenizer: VideoTokenizer | None = None,
2002
+ max_steps = 64, # K_max in paper
2003
+ num_register_tokens = 8, # they claim register tokens led to better temporal consistency
2004
+ num_spatial_tokens = 2, # latents are projected to a different number of spatial tokens
2005
+ num_latent_tokens = None,
2006
+ num_agents = 1,
2007
+ num_tasks = 0,
2008
+ num_video_views = 1,
2009
+ dim_proprio = None,
2010
+ reward_encoder_kwargs: dict = dict(),
2011
+ depth = 4,
2012
+ pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than the velocity (x-space vs v-space)
2013
+ time_block_every = 4, # every 4th block is time
2014
+ attn_kwargs: dict = dict(),
2015
+ transformer_kwargs: dict = dict(),
2016
+ attn_heads = 8,
2017
+ attn_dim_head = 64,
2018
+ attn_softclamp_value = 50.,
2019
+ ff_kwargs: dict = dict(),
2020
+ loss_weight_fn: Callable = ramp_weight,
2021
+ num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
2022
+ prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
2023
+ add_reward_embed_to_agent_token = False,
2024
+ add_reward_embed_dropout = 0.1,
2025
+ num_discrete_actions: int | tuple[int, ...] = 0,
2026
+ num_continuous_actions = 0,
2027
+ continuous_norm_stats = None,
2028
+ multi_token_pred_len = 8,
2029
+ value_head_mlp_depth = 3,
2030
+ policy_head_mlp_depth = 3,
2031
+ latent_flow_loss_weight = 1.,
2032
+ reward_loss_weight: float | list[float] = 1.,
2033
+ discrete_action_loss_weight: float | list[float] = 1.,
2034
+ continuous_action_loss_weight: float | list[float] = 1.,
2035
+ num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
2036
+ num_residual_streams = 1,
2037
+ keep_reward_ema_stats = False,
2038
+ reward_ema_decay = 0.998,
2039
+ reward_quantile_filter = (0.05, 0.95),
2040
+ gae_discount_factor = 0.997,
2041
+ gae_lambda = 0.95,
2042
+ ppo_eps_clip = 0.2,
2043
+ pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
2044
+ pmpo_reverse_kl = True,
2045
+ pmpo_kl_div_loss_weight = .3,
2046
+ normalize_advantages = None,
2047
+ value_clip = 0.4,
2048
+ policy_entropy_weight = .01,
2049
+ gae_use_accelerated = False
2050
+ ):
2051
+ super().__init__()
2052
+
2053
+ # can accept raw video if tokenizer is passed in
2054
+
2055
+ self.video_tokenizer = video_tokenizer
2056
+
2057
+ if exists(video_tokenizer):
2058
+ num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
2059
+ assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
2060
+
2061
+ assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
2062
+
2063
+ # spatial
2064
+
2065
+ self.num_latent_tokens = num_latent_tokens
2066
+ self.dim_latent = dim_latent
2067
+ self.latent_shape = (num_latent_tokens, dim_latent)
2068
+
2069
+ if num_spatial_tokens >= num_latent_tokens:
2070
+ assert divisible_by(num_spatial_tokens, num_latent_tokens)
2071
+
2072
+ expand_factor = num_spatial_tokens // num_latent_tokens
2073
+
2074
+ self.latents_to_spatial_tokens = Sequential(
2075
+ Linear(dim_latent, dim * expand_factor),
2076
+ Rearrange('... (s d) -> ... s d', s = expand_factor)
2077
+ )
2078
+
2079
+ self.to_latent_pred = Sequential(
2080
+ Reduce('b t v n s d -> b t v n d', 'mean'),
2081
+ RMSNorm(dim),
2082
+ LinearNoBias(dim, dim_latent)
2083
+ )
2084
+
2085
+ else:
2086
+ assert divisible_by(num_latent_tokens, num_spatial_tokens)
2087
+ latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
2088
+
2089
+ self.latents_to_spatial_tokens = Sequential(
2090
+ Rearrange('... n d -> ... (n d)'),
2091
+ Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
2092
+ Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
2093
+ )
2094
+
2095
+ self.to_latent_pred = Sequential(
2096
+ RMSNorm(dim),
2097
+ LinearNoBias(dim, dim_latent * latent_tokens_to_space),
2098
+ Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
2099
+ )
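+ # both branches map the per-frame latents to `num_spatial_tokens` transformer tokens and back - by expanding each latent when there are more spatial tokens than latents, or by merging the latents when there are fewer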
2100
+
2101
+ # number of video views, for robotics, which could have third person + wrist camera at least
2102
+
2103
+ assert num_video_views >= 1
2104
+ self.video_has_multi_view = num_video_views > 1
2105
+
2106
+ self.num_video_views = num_video_views
2107
+
2108
+ if self.video_has_multi_view:
2109
+ self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
2110
+
2111
+ # proprioception
2112
+
2113
+ self.has_proprio = exists(dim_proprio)
2114
+ self.dim_proprio = dim_proprio
2115
+
2116
+ if self.has_proprio:
2117
+ self.to_proprio_token = nn.Linear(dim_proprio, dim)
2118
+
2119
+ self.to_proprio_pred = Sequential(
2120
+ RMSNorm(dim),
2121
+ nn.Linear(dim, dim_proprio)
2122
+ )
2123
+
2124
+ # register tokens
2125
+
2126
+ self.num_register_tokens = num_register_tokens
2127
+ self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
2128
+
2129
+ # signal and step sizes
2130
+
2131
+ assert divisible_by(dim, 2)
2132
+ dim_half = dim // 2
2133
+
2134
+ assert is_power_two(max_steps), '`max_steps` must be a power of 2'
2135
+ self.max_steps = max_steps
2136
+ self.num_step_sizes_log2 = int(log2(max_steps))
2137
+
2138
+ self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
2139
+ self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
2140
+
2141
+ self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
2142
+
2143
+ # loss related
2144
+
2145
+ self.pred_orig_latent = pred_orig_latent # x-space or v-space
2146
+ self.loss_weight_fn = loss_weight_fn
2147
+
2148
+ # reinforcement related
2149
+
2150
+ # they sum all the actions into a single token
2151
+
2152
+ self.num_agents = num_agents
2153
+
2154
+ self.agent_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
2155
+ self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
2156
+
2157
+ self.reward_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
2158
+
2159
+ self.num_tasks = num_tasks
2160
+ self.task_embed = nn.Embedding(num_tasks, dim)
2161
+
2162
+ # learned set of latent genes
2163
+
2164
+ self.agent_has_genes = num_latent_genes > 0
2165
+ self.num_latent_genes = num_latent_genes
2166
+ self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
2167
+
2168
+ # policy head
2169
+
2170
+ self.policy_head = create_mlp(
2171
+ dim_in = dim,
2172
+ dim = dim * 4,
2173
+ dim_out = dim * 4,
2174
+ depth = policy_head_mlp_depth
2175
+ )
2176
+
2177
+ # action embedder
2178
+
2179
+ self.action_embedder = ActionEmbedder(
2180
+ dim = dim,
2181
+ num_discrete_actions = num_discrete_actions,
2182
+ num_continuous_actions = num_continuous_actions,
2183
+ continuous_norm_stats = continuous_norm_stats,
2184
+ can_unembed = True,
2185
+ unembed_dim = dim * 4,
2186
+ num_unembed_preds = multi_token_pred_len,
2187
+ squeeze_unembed_preds = False
2188
+ )
2189
+
2190
+ # multi token prediction length
2191
+
2192
+ self.multi_token_pred_len = multi_token_pred_len
2193
+
2194
+ # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
2195
+
2196
+ self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
2197
+ self.add_reward_embed_dropout = add_reward_embed_dropout
2198
+
2199
+ self.reward_encoder = SymExpTwoHot(
2200
+ **reward_encoder_kwargs,
2201
+ dim_embed = dim,
2202
+ learned_embedding = add_reward_embed_to_agent_token
2203
+ )
2204
+
2205
+ to_reward_pred = Sequential(
2206
+ RMSNorm(dim),
2207
+ LinearNoBias(dim, self.reward_encoder.num_bins)
2208
+ )
2209
+
2210
+ self.to_reward_pred = Ensemble(
2211
+ to_reward_pred,
2212
+ multi_token_pred_len
2213
+ )
2214
+
2215
+ # value head
2216
+
2217
+ self.value_head = create_mlp(
2218
+ dim_in = dim,
2219
+ dim = dim * 4,
2220
+ dim_out = self.reward_encoder.num_bins,
2221
+ depth = value_head_mlp_depth,
735
2222
  )
736
- ):
737
- super().__init__()
738
2223
 
739
- self.patch_size = patch_size
2224
+ # efficient axial space / time transformer
2225
+
2226
+ self.transformer = AxialSpaceTimeTransformer(
2227
+ dim = dim,
2228
+ depth = depth,
2229
+ attn_heads = attn_heads,
2230
+ attn_dim_head = attn_dim_head,
2231
+ attn_softclamp_value = attn_softclamp_value,
2232
+ attn_kwargs = attn_kwargs,
2233
+ ff_kwargs = ff_kwargs,
2234
+ num_residual_streams = num_residual_streams,
2235
+ num_special_spatial_tokens = num_agents,
2236
+ time_block_every = time_block_every,
2237
+ final_norm = False,
2238
+ **transformer_kwargs
2239
+ )
740
2240
 
741
- # special tokens
2241
+ # ppo related
742
2242
 
743
- assert num_latent_tokens >= 1
744
- self.num_latent_tokens = num_latent_tokens
745
- self.latent_tokens = Parameter(randn(num_latent_tokens, dim) * 1e-2)
2243
+ self.gae_use_accelerated = gae_use_accelerated
2244
+ self.gae_discount_factor = gae_discount_factor
2245
+ self.gae_lambda = gae_lambda
746
2246
 
747
- # mae masking - Kaiming He paper from long ago
2247
+ self.ppo_eps_clip = ppo_eps_clip
2248
+ self.value_clip = value_clip
2249
+ self.policy_entropy_weight = policy_entropy_weight
748
2250
 
749
- self.per_image_patch_mask_prob = per_image_patch_mask_prob
750
- self.mask_token = Parameter(randn(dim) * 1e-2)
2251
+ # pmpo related
751
2252
 
752
- # patch and unpatch
2253
+ self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
2254
+ self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
2255
+ self.pmpo_reverse_kl = pmpo_reverse_kl
753
2256
 
754
- dim_patch = channels * patch_size ** 2
2257
+ # rewards related
755
2258
 
756
- self.patch_to_tokens = Sequential(
757
- Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
758
- Linear(dim_patch, dim)
759
- )
2259
+ self.keep_reward_ema_stats = keep_reward_ema_stats
2260
+ self.reward_ema_decay = reward_ema_decay
760
2261
 
761
- self.tokens_to_patch = Sequential(
762
- Linear(dim, dim_patch),
763
- Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
764
- )
2262
+ self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
765
2263
 
766
- # 3d rotations
2264
+ self.register_buffer('ema_returns_mean', tensor(0.))
2265
+ self.register_buffer('ema_returns_var', tensor(1.))
767
2266
 
768
- self.spacetime_rotary = GoldenGateRoPENd(
769
- dim_pos = 3,
770
- heads = attn_heads,
771
- dim_head = attn_dim_head,
772
- **nd_rotary_kwargs
773
- )
2267
+ # loss related
774
2268
 
775
- # attention related
2269
+ self.flow_loss_normalizer = LossNormalizer(1)
2270
+ self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
2271
+ self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
2272
+ self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
776
2273
 
777
- self.attn_softclamp_value = attn_softclamp_value
2274
+ self.latent_flow_loss_weight = latent_flow_loss_weight
778
2275
 
779
- # encoder
2276
+ self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
2277
+ self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
2278
+ self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
780
2279
 
781
- encoder_layers = []
2280
+ assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
2281
+ assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
2282
+ assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
782
2283
 
783
- for _ in range(encoder_depth):
784
- encoder_layers.append(ModuleList([
785
- Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
786
- SwiGLUFeedforward(dim = dim, **ff_kwargs)
787
- ]))
2284
+ self.register_buffer('zero', tensor(0.), persistent = False)
788
2285
 
789
- self.encoder_layers = ModuleList(encoder_layers)
790
- self.encoder_norm = RMSNorm(dim)
2286
+ @property
2287
+ def device(self):
2288
+ return self.zero.device
791
2289
 
792
- # latents
2290
+ # types of parameters
793
2291
 
794
- self.encoded_to_latents = Sequential(
795
- LinearNoBias(dim, dim_latent),
796
- nn.Tanh(),
797
- )
2292
+ def muon_parameters(self):
2293
+ return self.transformer.muon_parameters()
798
2294
 
799
- self.latents_to_decoder = LinearNoBias(dim_latent, dim)
2295
+ def policy_head_parameters(self):
2296
+ return [
2297
+ *self.policy_head.parameters(),
2298
+ *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
2299
+ ]
800
2300
 
801
- # decoder
2301
+ def value_head_parameters(self):
2302
+ return self.value_head.parameters()
802
2303
 
803
- self.image_height = image_height
804
- self.image_width = image_width
2304
+ def parameter(self):
2305
+ params = super().parameters()
805
2306
 
806
- # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
2307
+ if not exists(self.video_tokenizer):
2308
+ return params
807
2309
 
808
- self.to_decoder_pos_emb = create_mlp(
809
- dim_in = 2,
810
- dim = dim * 2,
811
- dim_out = dim,
812
- depth = decoder_pos_mlp_depth,
813
- )
2310
+ return list(set(params) - set(self.video_tokenizer.parameters()))
814
2311
 
815
- decoder_layers = []
2312
+ # helpers for shortcut flow matching
816
2313
 
817
- for _ in range(decoder_depth):
818
- decoder_layers.append(ModuleList([
819
- Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
820
- SwiGLUFeedforward(dim = dim, **ff_kwargs)
821
- ]))
2314
+ def get_times_from_signal_level(
2315
+ self,
2316
+ signal_levels,
2317
+ align_dims_left_to = None
2318
+ ):
2319
+ times = signal_levels.float() / self.max_steps
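+ # signal levels are integers in [0, max_steps) - dividing by `max_steps` maps them to flow matching times in [0., 1.)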
822
2320
 
823
- self.decoder_layers = ModuleList(decoder_layers)
824
- self.decoder_norm = RMSNorm(dim)
2321
+ if not exists(align_dims_left_to):
2322
+ return times
825
2323
 
826
- # loss related
2324
+ return align_dims_left(times, align_dims_left_to)
827
2325
 
828
- self.register_buffer('zero', tensor(0.), persistent = False)
2326
+ # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
829
2327
 
830
- self.has_lpips_loss = lpips_loss_weight > 0.
831
- self.lpips_loss_weight = lpips_loss_weight
2328
+ @torch.no_grad()
2329
+ def evolve_(
2330
+ self,
2331
+ fitness,
2332
+ select_frac = 0.5,
2333
+ tournament_frac = 0.5
2334
+ ):
2335
+ assert fitness.numel() == self.num_latent_genes
832
2336
 
833
- if self.has_lpips_loss:
834
- self.lpips = LPIPSLoss(lpips_loss_network)
2337
+ pop = self.latent_genes
835
2338
 
836
- @property
837
- def device(self):
838
- return self.zero.device
2339
+ pop_size = self.num_latent_genes
2340
+ num_selected = ceil(pop_size * select_frac)
2341
+ num_children = pop_size - num_selected
2342
+
2343
+ dim_gene = pop.shape[-1]
2344
+
2345
+ # natural selection just a sort and slice
2346
+
2347
+ selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
2348
+ selected = pop[selected_indices]
2349
+
2350
+ # use tournament - one tournament per child
2351
+
2352
+ tournament_size = max(2, ceil(num_selected * tournament_frac))
2353
+
2354
+ tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
2355
+
2356
+ parent_ids = selected_fitness[tournaments].topk(2, dim = -1).indices # get top 2 winners as parents
2357
+
2358
+ parents = selected[parent_ids]
2359
+
2360
+ # crossover by random interpolation from parent1 to parent2
2361
+
2362
+ random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
2363
+
2364
+ parent1, parent2 = parents.unbind(dim = 1)
2365
+ children = parent1.lerp(parent2, random_uniform_mix)
2366
+
2367
+ # store next population
2368
+
2369
+ next_pop = cat((selected, children))
2370
+
2371
+ self.latent_genes.copy_(next_pop)
2372
+
2373
+ # interacting with env for experience
839
2374
 
840
2375
  @torch.no_grad()
841
- def tokenize(
2376
+ def interact_with_env(
842
2377
  self,
843
- video
2378
+ env,
2379
+ seed = None,
2380
+ agent_index = 0,
2381
+ step_size = 4,
2382
+ max_timesteps = 16,
2383
+ env_is_vectorized = False,
2384
+ use_time_kv_cache = True,
2385
+ store_agent_embed = True,
2386
+ store_old_action_unembeds = True,
844
2387
  ):
845
- self.eval()
846
- return self.forward(video, return_latents = True)
2388
+ assert exists(self.video_tokenizer)
847
2389
 
848
- def get_rotary_pos_emb(
849
- self,
850
- time,
851
- num_patch_height,
852
- num_patch_width
853
- ):
854
- device = self.device
2390
+ init_frame = env.reset()
855
2391
 
856
- positions = stack(torch.meshgrid(
857
- arange(time, device = device),
858
- arange(num_patch_height, device = device),
859
- arange(num_patch_width, device = device)
860
- ), dim = -1)
2392
+ # frame to video
861
2393
 
862
- positions = rearrange(positions, 't h w p -> t (h w) p')
2394
+ if env_is_vectorized:
2395
+ video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
2396
+ else:
2397
+ video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
863
2398
 
864
- # give the latents an out of bounds position and assume the network will figure it out
2399
+ batch, device = video.shape[0], video.device
865
2400
 
866
- positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
2401
+ # accumulate
867
2402
 
868
- positions = rearrange(positions, 't hw p -> (t hw) p')
2403
+ rewards = None
2404
+ discrete_actions = None
2405
+ continuous_actions = None
2406
+ discrete_log_probs = None
2407
+ continuous_log_probs = None
2408
+ values = None
2409
+ latents = None
869
2410
 
870
- return self.spacetime_rotary(positions)
2411
+ acc_agent_embed = None
2412
+ acc_policy_embed = None
871
2413
 
872
- def decode(
873
- self,
874
- latents, # (b t n d)
875
- height = None,
876
- width = None,
877
- rotary_pos_emb = None
878
- ): # (b c t h w)
2414
+ # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
879
2415
 
880
- height = default(height, self.image_height)
881
- width = default(width, self.image_width)
2416
+ is_terminated = full((batch,), False, device = device)
2417
+ is_truncated = full((batch,), False, device = device)
882
2418
 
883
- assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
2419
+ episode_lens = full((batch,), 0, device = device)
884
2420
 
885
- batch, time, device = *latents.shape[:2], latents.device
2421
+ # maybe time kv cache
886
2422
 
887
- use_flex = latents.is_cuda and exists(flex_attention)
2423
+ time_kv_cache = None
888
2424
 
889
- num_patch_height = height // self.patch_size
890
- num_patch_width = width // self.patch_size
2425
+ step_index = 0
891
2426
 
892
- if not exists(rotary_pos_emb):
893
- rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
2427
+ while not is_terminated.all():
2428
+ step_index += 1
894
2429
 
895
- # latents to tokens
2430
+ latents = self.video_tokenizer(video, return_latents = True)
896
2431
 
897
- latent_tokens = self.latents_to_decoder(latents)
2432
+ _, (agent_embed, next_time_kv_cache) = self.forward(
2433
+ latents = latents,
2434
+ signal_levels = self.max_steps - 1,
2435
+ step_sizes = step_size,
2436
+ rewards = rewards,
2437
+ discrete_actions = discrete_actions,
2438
+ continuous_actions = continuous_actions,
2439
+ time_kv_cache = time_kv_cache,
2440
+ latent_is_noised = True,
2441
+ return_pred_only = True,
2442
+ return_intermediates = True
2443
+ )
898
2444
 
899
- # generate decoder positional embedding and concat the latent token
2445
+ # time kv cache
900
2446
 
901
- spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
902
- spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
2447
+ if use_time_kv_cache:
2448
+ time_kv_cache = next_time_kv_cache
903
2449
 
904
- space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
2450
+ # get one agent
905
2451
 
906
- decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
907
- decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
2452
+ one_agent_embed = agent_embed[..., -1:, agent_index, :]
908
2453
 
909
- tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
2454
+ # values
910
2455
 
911
- # pack time
2456
+ value_bins = self.value_head(one_agent_embed)
2457
+ value = self.reward_encoder.bins_to_scalar_value(value_bins)
912
2458
 
913
- tokens, inverse_pack_time = pack_one(tokens, 'b * d')
2459
+ values = safe_cat((values, value), dim = 1)
914
2460
 
915
- seq_len = tokens.shape[-2]
2461
+ # policy embed
916
2462
 
917
- # decoder attend
2463
+ policy_embed = self.policy_head(one_agent_embed)
918
2464
 
919
- decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = True)
2465
+ if store_old_action_unembeds:
2466
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
920
2467
 
921
- # decoder attention
2468
+ # sample actions
922
2469
 
923
- for attn, ff in self.decoder_layers:
924
- tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
2470
+ sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
925
2471
 
926
- tokens = ff(tokens) + tokens
2472
+ discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
2473
+ continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
927
2474
 
928
- tokens = self.decoder_norm(tokens)
2475
+ # get the log prob and values for policy optimization
929
2476
 
930
- # unpack time
2477
+ one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
2478
+ policy_embed,
2479
+ pred_head_index = 0,
2480
+ discrete_targets = sampled_discrete_actions,
2481
+ continuous_targets = sampled_continuous_actions,
2482
+ )
931
2483
 
932
- tokens = inverse_pack_time(tokens)
2484
+ discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
2485
+ continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
933
2486
 
934
- # unpack latents
2487
+ # pass the sampled action to the environment and get back next state and reward
935
2488
 
936
- tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
2489
+ env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
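+ # accept environments that return (frame, reward), (frame, reward, terminated), the gym-style (frame, reward, terminated, truncated), or the latter with a trailing info dict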
937
2490
 
938
- # project back to patches
2491
+ if len(env_step_out) == 2:
2492
+ next_frame, reward = env_step_out
2493
+ terminated = full((batch,), False, device = device)
2493
+ truncated = full((batch,), False, device = device)
939
2495
 
940
- recon_video = self.tokens_to_patch(tokens)
2496
+ elif len(env_step_out) == 3:
2497
+ next_frame, reward, terminated = env_step_out
2498
+ truncated = full((batch,), False, device = device)
941
2499
 
942
- return recon_video
2500
+ elif len(env_step_out) == 4:
2501
+ next_frame, reward, terminated, truncated = env_step_out
943
2502
 
944
- def forward(
945
- self,
946
- video, # (b c t h w)
947
- return_latents = False,
948
- mask_patches = None,
949
- return_all_losses = False
950
- ):
951
- batch, _, time, height, width = video.shape
952
- patch_size, device = self.patch_size, video.device
2503
+ elif len(env_step_out) == 5:
2504
+ next_frame, reward, terminated, truncated, info = env_step_out
953
2505
 
954
- assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
2506
+ # update episode lens
955
2507
 
956
- # to tokens
2508
+ episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
957
2509
 
958
- tokens = self.patch_to_tokens(video)
2510
+ # update `is_terminated`
959
2511
 
960
- # get some dimensions
2512
+ # (1) - environment says it is terminated
2513
+ # (2) - previous step is truncated (this step is for bootstrap value)
961
2514
 
962
- num_patch_height, num_patch_width, _ = tokens.shape[-3:]
2515
+ is_terminated |= (terminated | is_truncated)
963
2516
 
964
- # rotary positions
2517
+ # update `is_truncated`
965
2518
 
966
- rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
2519
+ if step_index <= max_timesteps:
2520
+ is_truncated |= truncated
967
2521
 
968
- # masking
2522
+ if step_index == max_timesteps:
2523
+ # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
969
2524
 
970
- mask_patches = default(mask_patches, self.training)
2525
+ is_truncated |= ~is_terminated
971
2526
 
972
- if mask_patches:
973
- min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
2527
+ # batch and time dimension
974
2528
 
975
- mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t)
2529
+ if env_is_vectorized:
2530
+ next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
2531
+ reward = rearrange(reward, 'b -> b 1')
2532
+ else:
2533
+ next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
2534
+ reward = rearrange(reward, ' -> 1 1')
2535
+
2536
+ # concat
2537
+
2538
+ video = cat((video, next_frame), dim = 2)
2539
+ rewards = safe_cat((rewards, reward), dim = 1)
2540
+
2541
+ acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
2542
+
2543
+ # package up one experience for learning
2544
+
2545
+ batch, device = latents.shape[0], latents.device
2546
+
2547
+ one_experience = Experience(
2548
+ latents = latents,
2549
+ video = video[:, :, :-1],
2550
+ rewards = rewards,
2551
+ actions = (discrete_actions, continuous_actions),
2552
+ log_probs = (discrete_log_probs, continuous_log_probs),
2553
+ values = values,
2554
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
2555
+ agent_embed = acc_agent_embed if store_agent_embed else None,
2556
+ step_size = step_size,
2557
+ agent_index = agent_index,
2558
+ is_truncated = is_truncated,
2559
+ lens = episode_lens,
2560
+ is_from_world_model = False
2561
+ )
976
2562
 
977
- mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
978
- mask_patch = torch.bernoulli(mask_prob) == 1.
2563
+ return one_experience
979
2564
 
980
- tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
2565
+ # ppo
981
2566
 
982
- # pack space
2567
+ def learn_from_experience(
2568
+ self,
2569
+ experience: Experience,
2570
+ policy_optim: Optimizer | None = None,
2571
+ value_optim: Optimizer | None = None,
2572
+ only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
2573
+ use_pmpo = True,
2574
+ normalize_advantages = None,
2575
+ eps = 1e-6
2576
+ ):
2577
+ assert isinstance(experience, Experience)
983
2578
 
984
- tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
2579
+ experience = experience.to(self.device)
985
2580
 
986
- # add the latent
2581
+ latents = experience.latents
2582
+ actions = experience.actions
2583
+ old_log_probs = experience.log_probs
2584
+ old_values = experience.values
2585
+ rewards = experience.rewards
2586
+ agent_embeds = experience.agent_embed
2587
+ old_action_unembeds = experience.old_action_unembeds
987
2588
 
988
- latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
2589
+ step_size = experience.step_size
2590
+ agent_index = experience.agent_index
989
2591
 
990
- tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
2592
+ assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
991
2593
 
992
- space_seq_len = tokens.shape[-2]
2594
+ batch, time = latents.shape[0], latents.shape[1]
993
2595
 
994
- # pack time
2596
+ # calculate returns
995
2597
 
996
- tokens, inverse_pack_time = pack_one(tokens, 'b * d')
2598
+ # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
997
2599
 
998
- seq_len = tokens.shape[1]
2600
+ if not exists(experience.is_truncated):
2601
+ experience.is_truncated = full((batch,), True, device = latents.device)
999
2602
 
1000
- # attend hyper parameters
2603
+ if exists(experience.lens):
2604
+ mask_for_gae = lens_to_mask(experience.lens, time)
1001
2605
 
1002
- attend_kwargs = dict(
1003
- causal = True,
1004
- causal_block_size = space_seq_len,
1005
- softclamp_value = self.attn_softclamp_value,
1006
- block_size_per_special = space_seq_len,
1007
- num_special_tokens = 1
1008
- )
2606
+ rewards = rewards.masked_fill(~mask_for_gae, 0.)
2607
+ old_values = old_values.masked_fill(~mask_for_gae, 0.)
1009
2608
 
1010
- use_flex = tokens.is_cuda and exists(flex_attention)
2609
+ # calculate returns
1011
2610
 
1012
- # encoder attend
2611
+ returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
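+ # the gae `returns` serve both as the value targets and, against the old values, as the advantages below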
1013
2612
 
1014
- # modality can only attend to itself while latents can attend to everything
1015
- # similar to agent token in dynamics model
2613
+ # handle variable lengths
1016
2614
 
1017
- encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, causal = True, num_special_tokens = self.num_latent_tokens, special_attend_only_itself = False)
2615
+ max_time = latents.shape[1]
2616
+ is_var_len = exists(experience.lens)
1018
2617
 
1019
- # encoder
2618
+ mask = None
1020
2619
 
1021
- for attn, ff in self.encoder_layers:
1022
- tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
1023
- tokens = ff(tokens) + tokens
2620
+ if is_var_len:
2621
+ learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
2622
+ mask = lens_to_mask(learnable_lens, max_time)
1024
2623
 
1025
- tokens = self.encoder_norm(tokens)
2624
+ # determine whether to finetune entire transformer or just learn the heads
1026
2625
 
1027
- # latent bottleneck
2626
+ world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
1028
2627
 
1029
- tokens = inverse_pack_time(tokens)
2628
+ # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
1030
2629
 
1031
- tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
2630
+ if self.keep_reward_ema_stats:
2631
+ ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
1032
2632
 
1033
- latents = self.encoded_to_latents(latents)
2633
+ decay = 1. - self.reward_ema_decay
1034
2634
 
1035
- if return_latents:
1036
- return latents
2635
+ # quantile filter
1037
2636
 
1038
- recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
2637
+ lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
2638
+ returns_for_stats = returns.clamp(lo, hi)
1039
2639
 
1040
- # losses
2640
+ # mean, var - todo - handle distributed
1041
2641
 
1042
- recon_loss = F.mse_loss(video, recon_video)
2642
+ returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
1043
2643
 
1044
- lpips_loss = self.zero
2644
+ # ema
1045
2645
 
1046
- if self.has_lpips_loss:
1047
- lpips_loss = self.lpips(video, recon_video)
2646
+ ema_returns_mean.lerp_(returns_mean, decay)
2647
+ ema_returns_var.lerp_(returns_var, decay)
1048
2648
 
1049
- # losses
2649
+ # normalize
1050
2650
 
1051
- total_loss = (
1052
- recon_loss +
1053
- lpips_loss * self.lpips_loss_weight
1054
- )
2651
+ ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
1055
2652
 
1056
- if not return_all_losses:
1057
- return total_loss
2653
+ normed_returns = (returns - ema_returns_mean) / ema_returns_std
2654
+ normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
1058
2655
 
1059
- losses = (recon_loss, lpips_loss)
2656
+ advantage = normed_returns - normed_old_values
2657
+ else:
2658
+ advantage = returns - old_values
1060
2659
 
1061
- return total_loss, TokenizerLosses(losses)
2660
+ # if using pmpo, do not normalize advantages, but can be overridden
1062
2661
 
1063
- # dynamics model, axial space-time transformer
2662
+ normalize_advantages = default(normalize_advantages, not use_pmpo)
1064
2663
 
1065
- class DynamicsWorldModel(Module):
1066
- def __init__(
1067
- self,
1068
- dim,
1069
- dim_latent,
1070
- video_tokenizer: VideoTokenizer | None = None,
1071
- max_steps = 64, # K_max in paper
1072
- num_register_tokens = 8, # they claim register tokens led to better temporal consistency
1073
- num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
1074
- num_latent_tokens = None,
1075
- num_agents = 1,
1076
- num_tasks = 0,
1077
- reward_encoder_kwargs: dict = dict(),
1078
- depth = 4,
1079
- pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
1080
- time_block_every = 4, # every 4th block is time
1081
- attn_kwargs: dict = dict(
1082
- heads = 8,
1083
- ),
1084
- attn_dim_head = 64,
1085
- attn_softclamp_value = 50.,
1086
- ff_kwargs: dict = dict(),
1087
- loss_weight_fn: Callable = ramp_weight,
1088
- num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
1089
- prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
1090
- add_reward_embed_to_agent_token = False,
1091
- add_reward_embed_dropout = 0.1,
1092
- reward_loss_weight = 0.1,
1093
- value_head_mlp_depth = 3,
1094
- num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
1095
- ):
1096
- super().__init__()
2664
+ if normalize_advantages:
2665
+ advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
1097
2666
 
1098
- # can accept raw video if tokenizer is passed in
2667
+ # https://arxiv.org/abs/2410.04166v1
1099
2668
 
1100
- self.video_tokenizer = video_tokenizer
2669
+ if use_pmpo:
2670
+ pos_advantage_mask = advantage >= 0.
2671
+ neg_advantage_mask = ~pos_advantage_mask
1101
2672
 
1102
- if exists(video_tokenizer):
1103
- num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
1104
- assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
2673
+ # replay for the action logits and values
2674
+ # but only do so if fine tuning the entire world model for RL
1105
2675
 
1106
- assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
2676
+ discrete_actions, continuous_actions = actions
1107
2677
 
1108
- # spatial
2678
+ if (
2679
+ not only_learn_policy_value_heads or
2680
+ not exists(agent_embeds)
2681
+ ):
1109
2682
 
1110
- self.num_latent_tokens = num_latent_tokens
1111
- self.dim_latent = dim_latent
1112
- self.latent_shape = (num_latent_tokens, dim_latent)
2683
+ with world_model_forward_context():
2684
+ _, (agent_embeds, _) = self.forward(
2685
+ latents = latents,
2686
+ signal_levels = self.max_steps - 1,
2687
+ step_sizes = step_size,
2688
+ rewards = rewards,
2689
+ discrete_actions = discrete_actions,
2690
+ continuous_actions = continuous_actions,
2691
+ latent_is_noised = True,
2692
+ return_pred_only = True,
2693
+ return_intermediates = True
2694
+ )
1113
2695
 
1114
- if num_spatial_tokens >= num_latent_tokens:
1115
- assert divisible_by(num_spatial_tokens, num_latent_tokens)
2696
+ agent_embeds = agent_embeds[..., agent_index, :]
1116
2697
 
1117
- expand_factor = num_spatial_tokens // num_latent_tokens
2698
+ # maybe detach agent embed
1118
2699
 
1119
- self.latents_to_spatial_tokens = Sequential(
1120
- Linear(dim_latent, dim * expand_factor),
1121
- Rearrange('... (s d) -> ... s d', s = expand_factor)
1122
- )
2700
+ if only_learn_policy_value_heads:
2701
+ agent_embeds = agent_embeds.detach()
1123
2702
 
1124
- self.to_latent_pred = Sequential(
1125
- Reduce('b t n s d -> b t n d', 'mean'),
1126
- RMSNorm(dim),
1127
- LinearNoBias(dim, dim_latent)
1128
- )
2703
+ # ppo
1129
2704
 
1130
- else:
1131
- assert divisible_by(num_latent_tokens, num_spatial_tokens)
1132
- latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
2705
+ policy_embed = self.policy_head(agent_embeds)
1133
2706
 
1134
- self.latents_to_spatial_tokens = Sequential(
1135
- Rearrange('b t n d -> b t (n d)'),
1136
- Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
1137
- Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
1138
- )
2707
+ log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
1139
2708
 
1140
- self.to_latent_pred = Sequential(
1141
- RMSNorm(dim),
1142
- LinearNoBias(dim, dim_latent * latent_tokens_to_space),
1143
- Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
1144
- )
2709
+ # concat discrete and continuous actions into one for optimizing
1145
2710
 
1146
- # register tokens
2711
+ old_log_probs = safe_cat(old_log_probs, dim = -1)
2712
+ log_probs = safe_cat(log_probs, dim = -1)
2713
+ entropies = safe_cat(entropies, dim = -1)
1147
2714
 
1148
- self.num_register_tokens = num_register_tokens
1149
- self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
2715
+ advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
1150
2716
 
1151
- # signal and step sizes
2717
+ if use_pmpo:
2718
+ # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
2719
+ # seems to be weighted across batch and time, iiuc
2720
+ # eq (10) in https://arxiv.org/html/2410.04166v1
1152
2721
 
1153
- assert divisible_by(dim, 2)
1154
- dim_half = dim // 2
2722
+ if exists(mask):
2723
+ pos_advantage_mask &= mask
2724
+ neg_advantage_mask &= mask
1155
2725
 
1156
- assert is_power_two(max_steps), '`max_steps` must be a power of 2'
1157
- self.max_steps = max_steps
1158
- self.num_step_sizes_log2 = int(log2(max_steps))
2726
+ α = self.pmpo_pos_to_neg_weight
1159
2727
 
1160
- self.signal_levels_embed = nn.Embedding(max_steps, dim_half)
1161
- self.step_size_embed = nn.Embedding(self.num_step_sizes_log2, dim_half) # power of 2, so 1/1, 1/2, 1/4, 1/8 ... 1/Kmax
2728
+ pos = masked_mean(log_probs, pos_advantage_mask)
2729
+ neg = -masked_mean(log_probs, neg_advantage_mask)
1162
2730
 
1163
- self.prob_no_shortcut_train = default(prob_no_shortcut_train, self.num_step_sizes_log2 ** -1.)
2731
+ policy_loss = -(α * pos + (1. - α) * neg)
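+ # i.e. maximize the log prob of actions with positive advantage and minimize it for negative ones, weighted by α - the advantage magnitude itself is discarded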
1164
2732
 
1165
- # loss related
2733
+ # take care of kl
1166
2734
 
1167
- self.pred_orig_latent = pred_orig_latent # x-space or v-space
1168
- self.loss_weight_fn = loss_weight_fn
2735
+ if self.pmpo_kl_div_loss_weight > 0.:
2736
+
2737
+ new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
2738
+
2739
+ kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
1169
2740
 
1170
- # reinforcement related
2741
+ # mentioned that the "reverse direction for the prior KL" was used
2742
+ # make optional, as observed instability in toy task
1171
2743
 
1172
- # they sum all the actions into a single token
2744
+ if self.pmpo_reverse_kl:
2745
+ kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
1173
2746
 
1174
- self.num_agents = num_agents
1175
- self.action_learned_embed = Parameter(randn(self.num_agents, dim) * 1e-2)
2747
+ discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)
1176
2748
 
1177
- self.num_tasks = num_tasks
1178
- self.task_embed = nn.Embedding(num_tasks, dim)
2749
+ # accumulate discrete and continuous kl div
1179
2750
 
1180
- # learned set of latent genes
2751
+ kl_div_loss = 0.
1181
2752
 
1182
- self.agent_has_genes = num_latent_genes > 0
1183
- self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
2753
+ if exists(discrete_kl_div):
2754
+ kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
1184
2755
 
1185
- # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
2756
+ if exists(continuous_kl_div):
2757
+ kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
1186
2758
 
1187
- self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
1188
- self.add_reward_embed_dropout = add_reward_embed_dropout
2759
+ policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
1189
2760
 
1190
- self.reward_encoder = SymExpTwoHot(
1191
- **reward_encoder_kwargs,
1192
- dim_embed = dim,
1193
- learned_embedding = add_reward_embed_to_agent_token
1194
- )
2761
+ else:
2762
+ # ppo clipped surrogate loss
1195
2763
 
1196
- self.to_reward_pred = Sequential(
1197
- RMSNorm(dim),
1198
- LinearNoBias(dim, self.reward_encoder.num_bins)
1199
- )
2764
+ ratio = (log_probs - old_log_probs).exp()
2765
+ clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
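+ # standard ppo clipped surrogate below: maximize min(ratio * advantage, clipped_ratio * advantage), summed over the action dimensions then masked-averaged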
1200
2766
 
1201
- self.reward_loss_weight = reward_loss_weight
2767
+ policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
2768
+ policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
1202
2769
 
1203
- # value head
2770
+ policy_loss = masked_mean(policy_loss, mask)
1204
2771
 
1205
- self.value_head = create_mlp(
1206
- dim_in = dim,
1207
- dim = dim * 4,
1208
- dim_out = self.reward_encoder.num_bins,
1209
- depth = value_head_mlp_depth,
1210
- )
2772
+ # handle entropy loss for naive exploration bonus
1211
2773
 
1212
- # attention
2774
+ entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
1213
2775
 
1214
- self.attn_softclamp_value = attn_softclamp_value
2776
+ entropy_loss = masked_mean(entropy_loss, mask)
1215
2777
 
1216
- # time rotary embedding
2778
+ # total policy loss
1217
2779
 
1218
- self.time_rotary = Rotary1D(attn_dim_head)
2780
+ total_policy_loss = (
2781
+ policy_loss +
2782
+ entropy_loss * self.policy_entropy_weight
2783
+ )
1219
2784
 
1220
- # transformer
2785
+ # maybe take policy optimizer step
1221
2786
 
1222
- layers = []
1223
- is_time = []
2787
+ if exists(policy_optim):
2788
+ total_policy_loss.backward()
1224
2789
 
1225
- for i in range(depth):
1226
- layer_index = i + 1
2790
+ policy_optim.step()
2791
+ policy_optim.zero_grad()
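A hedged sketch of the PPO branch plus the entropy bonus, again with masked_mean as a hypothetical stand-in for the helper used above:

import torch
from einops import reduce

def masked_mean(t, mask):
    return (t * mask).sum() / mask.sum().clamp(min = 1)

def ppo_policy_loss(log_probs, old_log_probs, advantage, entropies, mask, eps_clip = 0.2, entropy_weight = 1e-3):
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)

    surrogate = -torch.min(ratio * advantage, clipped_ratio * advantage)
    surrogate = reduce(surrogate, 'b t na -> b t', 'sum')        # sum over the action dimension

    entropy_bonus = reduce(entropies, 'b t na -> b t', 'sum')    # naive exploration bonus

    return masked_mean(surrogate, mask) - masked_mean(entropy_bonus, mask) * entropy_weight

b, t, na = 2, 8, 3
loss = ppo_policy_loss(
    torch.randn(b, t, na), torch.randn(b, t, na),    # new / old log probs
    torch.randn(b, t, 1),                            # advantages
    torch.rand(b, t, na),                            # entropies
    torch.ones(b, t, dtype = torch.bool)             # valid mask
)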
1227
2792
 
1228
- is_time_block = divisible_by(layer_index, time_block_every)
1229
- is_time.append(is_time_block)
2793
+ # value loss
1230
2794
 
1231
- rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
1232
- rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
2795
+ value_bins = self.value_head(agent_embeds)
2796
+ values = self.reward_encoder.bins_to_scalar_value(value_bins)
1233
2797
 
1234
- layers.append(ModuleList([
1235
- rearrange_to_attend,
1236
- rearrange_from_attend,
1237
- Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
1238
- SwiGLUFeedforward(dim = dim, **ff_kwargs)
1239
- ]))
2798
+ clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
2799
+ clipped_value_bins = self.reward_encoder(clipped_values)
1240
2800
 
1241
- self.layers = ModuleList(layers)
1242
- self.is_time = is_time
2801
+ return_bins = self.reward_encoder(returns)
1243
2802
 
1244
- # zero
2803
+ value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
1245
2804
 
1246
- self.register_buffer('zero', tensor(0.), persistent = False)
2805
+ value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
2806
+ value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
1247
2807
 
1248
- @property
1249
- def device(self):
1250
- return self.zero.device
2808
+ value_loss = torch.maximum(value_loss_1, value_loss_2)
1251
2809
 
1252
- def get_times_from_signal_level(
1253
- self,
1254
- signal_levels,
1255
- align_dims_left_to = None
1256
- ):
1257
- times = signal_levels.float() / self.max_steps
2810
+ # maybe variable length
1258
2811
 
1259
- if not exists(align_dims_left_to):
1260
- return times
2812
+ if is_var_len:
2813
+ value_loss = value_loss[mask].mean()
2814
+ else:
2815
+ value_loss = value_loss.mean()
1261
2816
 
1262
- return align_dims_left(times, align_dims_left_to)
2817
+ # maybe take value optimizer step
1263
2818
 
1264
- def parameter(self):
1265
- params = super().parameters()
2819
+ if exists(value_optim):
2820
+ value_loss.backward()
1266
2821
 
1267
- if not exists(self.video_tokenizer):
1268
- return params
2822
+ value_optim.step()
2823
+ value_optim.zero_grad()
1269
2824
 
1270
- return list(set(params) - set(self.video_tokenizer.parameters()))
2825
+ return total_policy_loss, value_loss
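The value and reward heads above operate on binned targets. A sketch of a symlog two-hot codec in the spirit of SymExpTwoHot (an assumption based on the Dreamer line of work, not necessarily this package's exact implementation):

import torch

def symlog(x):
    return torch.sign(x) * torch.log1p(x.abs())

def symexp(x):
    return torch.sign(x) * (torch.exp(x.abs()) - 1.)

def two_hot(value, bins):                                   # bins: (num_bins,) in symlog space
    value = symlog(value).clamp(bins[0], bins[-1])
    idx = torch.searchsorted(bins, value).clamp(1, len(bins) - 1)
    lo, hi = bins[idx - 1], bins[idx]
    w_hi = (value - lo) / (hi - lo)
    encoding = torch.zeros(*value.shape, len(bins))
    encoding.scatter_(-1, (idx - 1)[..., None], (1. - w_hi)[..., None])
    encoding.scatter_(-1, idx[..., None], w_hi[..., None])
    return encoding

def bins_to_scalar(logits, bins):
    probs = logits.softmax(dim = -1)
    return symexp((probs * bins).sum(dim = -1))

bins = torch.linspace(-5., 5., 41)
encoding = two_hot(torch.tensor([3.7, -0.2]), bins)
recovered = bins_to_scalar(torch.log(encoding + 1e-8), bins)   # roughly recovers 3.7 and -0.2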
1271
2826
 
1272
2827
  @torch.no_grad()
1273
2828
  def generate(
@@ -1276,20 +2831,49 @@ class DynamicsWorldModel(Module):
1276
2831
  num_steps = 4,
1277
2832
  batch_size = 1,
1278
2833
  agent_index = 0,
2834
+ tasks: int | Tensor | None = None,
2835
+ latent_gene_ids = None,
1279
2836
  image_height = None,
1280
2837
  image_width = None,
1281
2838
  return_decoded_video = None,
1282
2839
  context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
1283
- return_rewards_per_frame = False
2840
+ time_kv_cache: Tensor | None = None,
2841
+ use_time_kv_cache = True,
2842
+ return_rewards_per_frame = False,
2843
+ return_agent_actions = False,
2844
+ return_log_probs_and_values = False,
2845
+ return_for_policy_optimization = False,
2846
+ return_time_kv_cache = False,
2847
+ store_agent_embed = True,
2848
+ store_old_action_unembeds = True
1284
2849
 
1285
2850
  ): # (b t n d) | (b c t h w)
1286
2851
 
2852
+ # handy flag for returning generations for rl
2853
+
2854
+ if return_for_policy_optimization:
2855
+ return_agent_actions |= True
2856
+ return_log_probs_and_values |= True
2857
+ return_rewards_per_frame |= True
2858
+
2859
+ # more variables
2860
+
2861
+ has_proprio = self.has_proprio
1287
2862
  was_training = self.training
1288
2863
  self.eval()
1289
2864
 
2865
+ # validation
2866
+
1290
2867
  assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
1291
2868
  assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
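The asserts require a power-of-two number of denoising steps. Assuming the step size is derived as max_steps // num_steps (the actual derivation lives in the hunk below), the arithmetic is:

from math import log2

max_steps = 64   # assumed value for illustration

for num_steps in (1, 2, 4, 8, 64):
    assert log2(num_steps).is_integer() and 0 < num_steps <= max_steps
    step_size = max_steps // num_steps
    print(num_steps, step_size)   # 1 -> 64, 2 -> 32, 4 -> 16, 8 -> 8, 64 -> 1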
1292
2869
 
2870
+ if isinstance(tasks, int):
2871
+ tasks = full((batch_size,), tasks, device = self.device)
2872
+
2873
+ assert not exists(tasks) or tasks.shape[0] == batch_size
2874
+
2875
+ # get state latent shape
2876
+
1293
2877
  latent_shape = self.latent_shape
1294
2878
 
1295
2879
  # derive step size
@@ -1299,12 +2883,41 @@ class DynamicsWorldModel(Module):
1299
2883
  # denoising
1300
2884
  # teacher forcing to start with
1301
2885
 
1302
- latents = empty((batch_size, 0, *latent_shape), device = self.device)
2886
+ latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
2887
+
2888
+ past_latents_context_noise = latents.clone()
2889
+
2890
+ # maybe internal state
2891
+
2892
+ if has_proprio:
2893
+ proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
2894
+
2895
+ past_proprio_context_noise = proprio.clone()
2896
+
2897
+ # maybe return actions
2898
+
2899
+ return_agent_actions |= return_log_probs_and_values
2900
+
2901
+ decoded_discrete_actions = None
2902
+ decoded_continuous_actions = None
2903
+
2904
+ # policy optimization related
1303
2905
 
1304
- past_context_noise = latents.clone()
2906
+ decoded_discrete_log_probs = None
2907
+ decoded_continuous_log_probs = None
2908
+ decoded_values = None
2909
+
2910
+ # maybe store agent embed
2911
+
2912
+ acc_agent_embed = None
2913
+
2914
+ # maybe store old actions for kl
2915
+
2916
+ acc_policy_embed = None
1305
2917
 
1306
2918
  # maybe return rewards
1307
2919
 
2920
+ decoded_rewards = None
1308
2921
  if return_rewards_per_frame:
1309
2922
  decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
1310
2923
 
@@ -1313,60 +2926,182 @@ class DynamicsWorldModel(Module):
1313
2926
  while latents.shape[1] < time_steps:
1314
2927
 
1315
2928
  curr_time_steps = latents.shape[1]
1316
- noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
1317
2929
 
1318
- for step in range(num_steps):
2930
+ # determine whether to take an extra step if
2931
+ # (1) using time kv cache
2932
+ # (2) decoding anything off agent embedding (rewards, actions, etc)
2933
+
2934
+ take_extra_step = (
2935
+ use_time_kv_cache or
2936
+ return_rewards_per_frame or
2937
+ store_agent_embed or
2938
+ return_agent_actions
2939
+ )
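A schematic of the extra-step idea (not the package's exact control flow): after the last denoising update, one more forward pass on the now-clean latent refreshes the time-wise kv cache and the agent embedding, and its prediction is discarded. Here model is a hypothetical callable returning (flow, kv_cache):

def generate_frame(model, noised, num_steps, dt, kv_cache = None):
    for step in range(num_steps):
        flow, _ = model(noised, kv_cache = kv_cache)
        noised = noised + flow * dt                     # regular denoising updates

    _, kv_cache = model(noised, kv_cache = kv_cache)    # extra pass on the clean latent
    return noised, kv_cache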
2940
+
2941
+ # prepare noised latent / proprio inputs
2942
+
2943
+ noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
2944
+
2945
+ noised_proprio = None
2946
+
2947
+ if has_proprio:
2948
+ noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
2949
+
2950
+ # denoising steps
2951
+
2952
+ for step in range(num_steps + int(take_extra_step)):
2953
+
2954
+ is_last_step = (step + 1) == (num_steps + int(take_extra_step))
+ is_last_step = (step + 1) == (num_steps + int(take_extra_step))
2955
+
1319
2956
  signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
2957
+
2958
+ # noising past latent context
2959
+
2960
+ noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
2961
+
2962
+ noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
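A tiny sketch of the past-context noising above: every already-generated frame is nudged toward a fixed, per-frame noise sample by a small factor before being packed in front of the frame currently being denoised:

import torch

context_signal_noise = 0.1
past_latents = torch.randn(1, 3, 1, 8, 16)       # (b, t, v, n, d), already denoised frames
fixed_noise = torch.randn_like(past_latents)     # sampled once per frame, then reused every step

noised_context = past_latents.lerp(fixed_noise, context_signal_noise)
# equivalent to (1 - 0.1) * past_latents + 0.1 * fixed_noise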
1320
2963
 
1321
- noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
2964
+ # handle proprio
1322
2965
 
1323
- noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
2966
+ noised_proprio_with_context = None
2967
+
2968
+ if has_proprio:
2969
+ noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
2970
+ noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
2971
+
2972
+ # proper signal levels
1324
2973
 
1325
2974
  signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
1326
2975
 
1327
- pred, agent_embed = self.forward(
2976
+ pred, (agent_embed, next_time_kv_cache) = self.forward(
1328
2977
  latents = noised_latent_with_context,
1329
2978
  signal_levels = signal_levels_with_context,
1330
2979
  step_sizes = step_size,
1331
2980
  rewards = decoded_rewards,
2981
+ tasks = tasks,
2982
+ latent_gene_ids = latent_gene_ids,
2983
+ discrete_actions = decoded_discrete_actions,
2984
+ continuous_actions = decoded_continuous_actions,
2985
+ proprio = noised_proprio_with_context,
2986
+ time_kv_cache = time_kv_cache,
1332
2987
  latent_is_noised = True,
2988
+ latent_has_view_dim = True,
1333
2989
  return_pred_only = True,
1334
- return_agent_tokens = True
2990
+ return_intermediates = True,
1335
2991
  )
1336
2992
 
1337
- _, pred = unpack(pred, pack_context_shape, 'b * n d')
2993
+ if use_time_kv_cache and is_last_step:
2994
+ time_kv_cache = next_time_kv_cache
2995
+
2996
+ # early break when the extra step only serves to compute the agent embedding off the cleaned latents for decoding
2997
+
2998
+ if take_extra_step and is_last_step:
2999
+ break
3000
+
3001
+ # maybe proprio
3002
+
3003
+ if has_proprio:
3004
+ pred, pred_proprio = pred
3005
+
3006
+ # unpack pred
3007
+
3008
+ _, pred = unpack(pred, pack_context_shape, 'b * v n d')
3009
+
3010
+ if has_proprio:
3011
+ _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
1338
3012
 
1339
3013
  # derive flow, based on whether in x-space or not
1340
3014
 
1341
- if self.pred_orig_latent:
1342
- times = self.get_times_from_signal_level(signal_levels, noised_latent)
1343
- flow = (pred - noised_latent) / (1. - times)
1344
- else:
1345
- flow = pred
3015
+ def denoise_step(pred, noised, signal_levels):
3016
+ if self.pred_orig_latent:
3017
+ times = self.get_times_from_signal_level(signal_levels)
3018
+ aligned_times = align_dims_left(times, noised)
3019
+
3020
+ flow = (pred - noised) / (1. - aligned_times)
3021
+ else:
3022
+ flow = pred
3023
+
3024
+ return flow * (step_size / self.max_steps)
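A quick numeric check (a sketch, not the package's code) that the x-space branch of denoise_step converts to the same flow a v-space model would predict:

import torch

t = torch.tensor(0.25)                      # signal level in [0, 1)
x0, x1 = torch.randn(4), torch.randn(4)     # noise and data
xt = x0.lerp(x1, t)                         # noised latent at time t

v_pred = x1 - x0                            # ideal v-space prediction (the flow)
x_pred = x1                                 # ideal x-space prediction (the clean latent)

v_from_x = (x_pred - xt) / (1. - t)         # the conversion used above
assert torch.allclose(v_pred, v_from_x, atol = 1e-6)

dt = 0.25                                   # step_size / max_steps
next_xt = xt + v_from_x * dt                # one denoising update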
1346
3025
 
1347
3026
  # denoise
1348
3027
 
1349
- noised_latent += flow * (step_size / self.max_steps)
3028
+ noised_latent += denoise_step(pred, noised_latent, signal_levels)
3029
+
3030
+ if has_proprio:
3031
+ noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
1350
3032
 
1351
3033
  denoised_latent = noised_latent # it is now denoised
1352
3034
 
3035
+ if has_proprio:
3036
+ denoised_proprio = noised_proprio
3037
+
1353
3038
  # take care of the rewards by predicting on the agent token embedding on the last denoising step
1354
3039
 
1355
3040
  if return_rewards_per_frame:
1356
3041
  one_agent_embed = agent_embed[:, -1:, agent_index]
1357
3042
 
1358
- reward_logits = self.to_reward_pred(one_agent_embed)
3043
+ reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
1359
3044
  pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
1360
3045
 
1361
3046
  decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
1362
3047
 
3048
+ # maybe store agent embed
3049
+
3050
+ if store_agent_embed:
3051
+ one_agent_embed = agent_embed[:, -1:, agent_index]
3052
+ acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
3053
+
3054
+ # decode the agent actions if needed
3055
+
3056
+ if return_agent_actions:
3057
+ assert self.action_embedder.has_actions
3058
+
3059
+ one_agent_embed = agent_embed[:, -1:, agent_index]
3060
+
3061
+ policy_embed = self.policy_head(one_agent_embed)
3062
+
3063
+ # maybe store old actions
3064
+
3065
+ if store_old_action_unembeds:
3066
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
3067
+
3068
+ # sample actions
3069
+
3070
+ sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
3071
+
3072
+ decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
3073
+ decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
3074
+
3075
+ if return_log_probs_and_values:
3076
+ discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
3077
+ policy_embed,
3078
+ pred_head_index = 0,
3079
+ discrete_targets = sampled_discrete_actions,
3080
+ continuous_targets = sampled_continuous_actions,
3081
+ )
3082
+
3083
+ decoded_discrete_log_probs = safe_cat((decoded_discrete_log_probs, discrete_log_probs), dim = 1)
3084
+ decoded_continuous_log_probs = safe_cat((decoded_continuous_log_probs, continuous_log_probs), dim = 1)
3085
+
3086
+ value_bins = self.value_head(one_agent_embed)
3087
+ values = self.reward_encoder.bins_to_scalar_value(value_bins)
3088
+
3089
+ decoded_values = safe_cat((decoded_values, values), dim = 1)
3090
+
1363
3091
  # concat the denoised latent
1364
3092
 
1365
3093
  latents = cat((latents, denoised_latent), dim = 1)
1366
3094
 
1367
3095
  # add new fixed context noise for the temporal consistency
1368
3096
 
1369
- past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
3097
+ past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
3098
+
3099
+ # handle proprio
3100
+
3101
+ if has_proprio:
3102
+ proprio = cat((proprio, denoised_proprio), dim = 1)
3103
+
3104
+ past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
1370
3105
 
1371
3106
  # restore state
1372
3107
 
@@ -1377,52 +3112,128 @@ class DynamicsWorldModel(Module):
1377
3112
  has_tokenizer = exists(self.video_tokenizer)
1378
3113
  return_decoded_video = default(return_decoded_video, has_tokenizer)
1379
3114
 
1380
- if not return_decoded_video:
1381
- if not return_rewards_per_frame:
1382
- return denoised_latents
3115
+ video = None
3116
+
3117
+ if return_decoded_video:
3118
+
3119
+ latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
3120
+ latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
3121
+
3122
+ video = self.video_tokenizer.decode(
3123
+ latents_for_video,
3124
+ height = image_height,
3125
+ width = image_width
3126
+ )
3127
+
3128
+ video = unpack_view(video, '* t c vh vw')
3129
+
3130
+ # remove the lone view dimension
1383
3131
 
1384
- return denoised_latents, decoded_rewards
3132
+ if not self.video_has_multi_view:
3133
+ latents = rearrange(latents, 'b t 1 ... -> b t ...')
1385
3134
 
1386
- generated_video = self.video_tokenizer.decode(
1387
- latents,
1388
- height = image_height,
1389
- width = image_width
3135
+ if exists(video):
3136
+ video = rearrange(video, 'b 1 ... -> b ...')
3137
+
3138
+ # only return video or latent if not requesting anything else, for first stage training
3139
+
3140
+ if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
3141
+ out = video if return_decoded_video else latents
3142
+
3143
+ if not return_time_kv_cache:
3144
+ return out
3145
+
3146
+ return out, time_kv_cache
3147
+
3148
+ # returning agent actions, rewards, and log probs + values for policy optimization
3149
+
3150
+ batch, device = latents.shape[0], latents.device
3151
+ experience_lens = full((batch,), time_steps, device = device)
3152
+
3153
+ gen = Experience(
3154
+ latents = latents,
3155
+ video = video,
3156
+ proprio = proprio if has_proprio else None,
3157
+ agent_embed = acc_agent_embed if store_agent_embed else None,
3158
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
3159
+ step_size = step_size,
3160
+ agent_index = agent_index,
3161
+ lens = experience_lens,
3162
+ is_from_world_model = True
1390
3163
  )
1391
3164
 
1392
- if not return_rewards_per_frame:
1393
- return generated_video
3165
+ if return_rewards_per_frame:
3166
+ gen.rewards = decoded_rewards
3167
+
3168
+ if return_agent_actions:
3169
+ gen.actions = (decoded_discrete_actions, decoded_continuous_actions)
3170
+
3171
+ if return_log_probs_and_values:
3172
+ gen.log_probs = (decoded_discrete_log_probs, decoded_continuous_log_probs)
3173
+
3174
+ gen.values = decoded_values
3175
+
3176
+ if not return_time_kv_cache:
3177
+ return gen
1394
3178
 
1395
- return generated_video, decoded_rewards
3179
+ return gen, time_kv_cache
1396
3180
 
1397
3181
  def forward(
1398
3182
  self,
1399
3183
  *,
1400
- video = None,
1401
- latents = None, # (b t n d) | (b t d)
1402
- signal_levels = None, # () | (b) | (b t)
1403
- step_sizes = None, # () | (b)
1404
- step_sizes_log2 = None, # () | (b)
1405
- tasks = None, # (b)
1406
- rewards = None, # (b t)
1407
- latent_gene_ids = None, # (b)
3184
+ video = None, # (b v? c t vh vw)
3185
+ latents = None, # (b t v? n d) | (b t v? d)
3186
+ lens = None, # (b)
3187
+ signal_levels = None, # () | (b) | (b t)
3188
+ step_sizes = None, # () | (b)
3189
+ step_sizes_log2 = None, # () | (b)
3190
+ latent_gene_ids = None, # (b)
3191
+ tasks = None, # (b)
3192
+ rewards = None, # (b t)
3193
+ discrete_actions = None, # (b t na) | (b t-1 na)
3194
+ continuous_actions = None, # (b t na) | (b t-1 na)
3195
+ discrete_action_types = None, # (na)
3196
+ continuous_action_types = None, # (na)
3197
+ proprio = None, # (b t dp)
3198
+ time_kv_cache = None,
1408
3199
  return_pred_only = False,
1409
3200
  latent_is_noised = False,
1410
3201
  return_all_losses = False,
1411
- return_agent_tokens = False
3202
+ return_intermediates = False,
3203
+ add_autoregressive_action_loss = True,
3204
+ update_loss_ema = None,
3205
+ latent_has_view_dim = False
1412
3206
  ):
1413
3207
  # handle video or latents
1414
3208
 
1415
3209
  assert exists(video) ^ exists(latents)
1416
3210
 
3211
+ # standardize view dimension
3212
+
3213
+ if not self.video_has_multi_view:
3214
+ if exists(video):
3215
+ video = rearrange(video, 'b ... -> b 1 ...')
3216
+
3217
+ if exists(latents) and not latent_has_view_dim:
3218
+ latents = rearrange(latents, 'b t ... -> b t 1 ...')
3219
+
3220
+ # if raw video passed in, tokenize
3221
+
1417
3222
  if exists(video):
3223
+ assert video.ndim == 6
3224
+
3225
+ video, unpack_views = pack_one(video, '* c t vh vw')
1418
3226
  assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
1419
3227
 
1420
3228
  latents = self.video_tokenizer.tokenize(video)
3229
+ latents = unpack_views(latents, '* t n d')
3230
+ latents = rearrange(latents, 'b v t n d -> b t v n d')
1421
3231
 
1422
- if latents.ndim == 3:
1423
- latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
3232
+ if latents.ndim == 4:
3233
+ latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
1424
3234
 
1425
- assert latents.shape[-2:] == self.latent_shape
3235
+ assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
3236
+ assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
1426
3237
 
1427
3238
  # variables
1428
3239
 
@@ -1497,29 +3308,30 @@ class DynamicsWorldModel(Module):
1497
3308
 
1498
3309
  # times is from 0 to 1
1499
3310
 
1500
- times = self.get_times_from_signal_level(signal_levels, latents)
3311
+ times = self.get_times_from_signal_level(signal_levels)
1501
3312
 
1502
3313
  if not latent_is_noised:
1503
3314
  # get the noise
1504
3315
 
1505
3316
  noise = randn_like(latents)
3317
+ aligned_times = align_dims_left(times, latents)
1506
3318
 
1507
3319
  # noise from 0 as noise to 1 as data
1508
3320
 
1509
- noised_latents = noise.lerp(latents, times)
3321
+ noised_latents = noise.lerp(latents, aligned_times)
1510
3322
 
1511
3323
  else:
1512
3324
  noised_latents = latents
1513
3325
 
1514
3326
  # reinforcement learning related
1515
3327
 
1516
- agent_tokens = repeat(self.action_learned_embed, '... d -> b ... d', b = batch)
3328
+ agent_tokens = repeat(self.agent_learned_embed, '... d -> b ... d', b = batch)
1517
3329
 
1518
3330
  if exists(tasks):
1519
3331
  assert self.num_tasks > 0
1520
3332
 
1521
3333
  task_embeds = self.task_embed(tasks)
1522
- agent_tokens = agent_tokens + task_embeds
3334
+ agent_tokens = add('b ... d, b d', agent_tokens, task_embeds)
1523
3335
 
1524
3336
  # maybe evolution
1525
3337
 
@@ -1527,13 +3339,15 @@ class DynamicsWorldModel(Module):
1527
3339
  assert exists(self.latent_genes)
1528
3340
  latent_genes = self.latent_genes[latent_gene_ids]
1529
3341
 
1530
- agent_tokens = einx.add('b ... d, b d', agent_tokens, latent_genes)
3342
+ agent_tokens = add('b ... d, b d', agent_tokens, latent_genes)
1531
3343
 
1532
3344
  # handle agent tokens w/ actions and task embeds
1533
3345
 
1534
3346
  agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
1535
3347
 
1536
- # maybe add a reward embedding to agent tokens
3348
+ # maybe reward tokens
3349
+
3350
+ reward_tokens = agent_tokens[:, :, 0:0]
1537
3351
 
1538
3352
  if exists(rewards):
1539
3353
  two_hot_encoding = self.reward_encoder(rewards)
@@ -1542,30 +3356,102 @@ class DynamicsWorldModel(Module):
1542
3356
  self.add_reward_embed_to_agent_token and
1543
3357
  (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
1544
3358
  ):
1545
- reward_embeds = self.reward_encoder.embed(two_hot_encoding)
3359
+ assert self.num_agents == 1
3360
+
3361
+ reward_tokens = self.reward_encoder.embed(two_hot_encoding)
3362
+
3363
+ pop_last_reward = int(reward_tokens.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training; during inference it is not yet known, so handle this edge case
3364
+
3365
+ reward_tokens = pad_at_dim(reward_tokens, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
3366
+
3367
+ reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
3368
+
3369
+ # maybe proprioception
3370
+
3371
+ assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
1546
3372
 
1547
- pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
3373
+ noised_proprio = None
1548
3374
 
1549
- reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward
3375
+ if self.has_proprio:
1550
3376
 
1551
- agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
3377
+ if not latent_is_noised:
3378
+ # get the noise
3379
+
3380
+ proprio_noise = randn_like(proprio)
3381
+ aligned_times = align_dims_left(times, proprio)
3382
+
3383
+ # noise from 0 as noise to 1 as data
3384
+
3385
+ noised_proprio = proprio_noise.lerp(proprio, aligned_times)
3386
+
3387
+ else:
3388
+ noised_proprio = proprio
3389
+
3390
+ # maybe create the action tokens
3391
+
3392
+ if exists(discrete_actions) or exists(continuous_actions):
3393
+ assert self.action_embedder.has_actions
3394
+ assert self.num_agents == 1, 'only one agent allowed for now'
3395
+
3396
+ action_tokens = self.action_embedder(
3397
+ discrete_actions = discrete_actions,
3398
+ discrete_action_types = discrete_action_types,
3399
+ continuous_actions = continuous_actions,
3400
+ continuous_action_types = continuous_action_types
3401
+ )
3402
+
3403
+ # handle first timestep not having an associated past action
3404
+
3405
+ if action_tokens.shape[1] == (time - 1):
3406
+ action_tokens = pad_at_dim(action_tokens, (1, 0), value = 0. , dim = 1)
3407
+
3408
+ action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
3409
+
3410
+ elif self.action_embedder.has_actions:
3411
+ action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
3412
+
3413
+ else:
3414
+ action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
1552
3415
 
1553
3416
  # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
1554
3417
 
1555
- def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
3418
+ def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
3419
+
1556
3420
  # latents to spatial tokens
1557
3421
 
1558
3422
  space_tokens = self.latents_to_spatial_tokens(noised_latents)
1559
3423
 
3424
+ # maybe add view embedding
3425
+
3426
+ if self.video_has_multi_view:
3427
+ space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
3428
+
3429
+ # merge spatial tokens
3430
+
1560
3431
  space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
1561
3432
 
1562
3433
  num_spatial_tokens = space_tokens.shape[-2]
1563
3434
 
3435
+ # action tokens
3436
+
3437
+ num_action_tokens = 1 if not is_empty(action_tokens) else 0
3438
+
3439
+ # reward tokens
3440
+
3441
+ num_reward_tokens = 1 if not is_empty(reward_tokens) else 0
3442
+
1564
3443
  # pack to tokens
1565
3444
  # [signal + step size embed] [latent space tokens] [register] [actions / agent]
1566
3445
 
1567
3446
  registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
1568
3447
 
3448
+ # maybe proprio
3449
+
3450
+ if exists(noised_proprio):
3451
+ proprio_token = self.to_proprio_token(noised_proprio)
3452
+ else:
3453
+ proprio_token = registers[:, :, 0:0]
3454
+
1569
3455
  # determine signal + step size embed for their diffusion forcing + shortcut
1570
3456
 
1571
3457
  signal_embed = self.signal_levels_embed(signal_levels)
@@ -1578,77 +3464,86 @@ class DynamicsWorldModel(Module):
1578
3464
 
1579
3465
  # pack to tokens for attending
1580
3466
 
1581
- tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
3467
+ tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
1582
3468
 
1583
- # attend functions for space and time
3469
+ # attention
1584
3470
 
1585
- seq_len = tokens.shape[1]
3471
+ tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
1586
3472
 
1587
- use_flex = exists(flex_attention) and tokens.is_cuda
3473
+ # unpack
1588
3474
 
1589
- attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
3475
+ flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
1590
3476
 
1591
- space_seq_len = (
1592
- + 1 # signal + step
1593
- + self.num_agents # action / agent tokens
1594
- + self.num_register_tokens
1595
- + num_spatial_tokens
1596
- )
3477
+ # pooling
1597
3478
 
1598
- space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = self.num_agents, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
3479
+ space_tokens = inverse_pack_space_per_latent(space_tokens)
1599
3480
 
1600
- time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
3481
+ pred = self.to_latent_pred(space_tokens)
1601
3482
 
1602
- # rotary
3483
+ # maybe proprio
1603
3484
 
1604
- rotary_pos_emb = self.time_rotary(time)
3485
+ if self.has_proprio:
3486
+ pred_proprio = self.to_proprio_pred(proprio_token)
1605
3487
 
1606
- # attention
3488
+ pred = (pred, pred_proprio)
1607
3489
 
1608
- for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
3490
+ # returning
1609
3491
 
1610
- tokens = pre_attn_rearrange(tokens)
3492
+ if not return_agent_tokens:
3493
+ return pred
1611
3494
 
1612
- # when is a axial time attention block, should be causal
3495
+ if not return_time_kv_cache:
3496
+ return pred, agent_tokens
1613
3497
 
1614
- attend_fn = time_attend if layer_is_time else space_attend
3498
+ return pred, (agent_tokens, next_time_kv_cache)
1615
3499
 
1616
- layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
3500
+ # curry into get_prediction the arguments that stay fixed across the first call and the shortcut consistency calls
1617
3501
 
1618
- # attention layer
3502
+ _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
1619
3503
 
1620
- tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
3504
+ # forward the network
1621
3505
 
1622
- tokens = post_attn_rearrange(tokens)
3506
+ pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
1623
3507
 
1624
- # feedforward layer
3508
+ if return_pred_only:
3509
+ if not return_intermediates:
3510
+ return pred
1625
3511
 
1626
- tokens = ff(tokens) + tokens
3512
+ return pred, (encoded_agent_tokens, next_time_kv_cache)
1627
3513
 
1628
- # unpack
3514
+ # pack the predictions to calculate flow for different modalities all at once
1629
3515
 
1630
- flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
3516
+ if self.has_proprio:
3517
+ pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
1631
3518
 
1632
- # pooling
3519
+ noised, _ = pack((noised_latents, noised_proprio), 'b t *')
3520
+ data, _ = pack((latents, proprio), 'b t *')
3521
+ noise, _ = pack((noise, proprio_noise), 'b t *')
3522
+ else:
3523
+ noised = noised_latents
3524
+ data = latents
1633
3525
 
1634
- space_tokens = inverse_pack_space_per_latent(space_tokens)
3526
+ # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
1635
3527
 
1636
- pred = self.to_latent_pred(space_tokens)
3528
+ def maybe_pack_unpack(fn):
3529
+ @wraps(fn)
3530
+ @torch.no_grad()
3531
+ def inner(noised, *args, **kwargs):
1637
3532
 
1638
- if not return_agent_tokens:
1639
- return pred
3533
+ noised_proprio = None
1640
3534
 
1641
- return pred, agent_tokens
3535
+ if self.has_proprio:
3536
+ noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
1642
3537
 
1643
- # forward the network
3538
+ pred = fn(noised, noised_proprio, *args, **kwargs)
1644
3539
 
1645
- pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
3540
+ if self.has_proprio:
3541
+ pred, _ = pack(pred, 'b t *')
1646
3542
 
1647
- if return_pred_only:
1648
- if not return_agent_tokens:
1649
3543
  return pred
3544
+ return inner
1650
3545
 
1651
- return pred, encoded_agent_tokens
3546
+ wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
1652
3547
 
1653
3548
  # determine the target for the loss
1654
3549
 
@@ -1665,46 +3560,45 @@ class DynamicsWorldModel(Module):
1665
3560
  # x-space as in paper is in else clause
1666
3561
 
1667
3562
  if is_v_space_pred:
1668
- pred_target = flow = latents - noise
3563
+ pred_target = flow = data - noise
1669
3564
  else:
1670
- pred_target = latents
3565
+ pred_target = data
1671
3566
  else:
1672
3567
  # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
1673
3568
 
1674
3569
  # basically a consistency loss where you ensure the result of two half steps equals that of one full step
1675
3570
  # dreamer then makes it work for x-space with some math
1676
3571
 
1677
- get_prediction_no_grad = torch.no_grad()(get_prediction)
1678
-
1679
3572
  step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
1680
3573
  half_step_size = 2 ** step_sizes_log2_minus_one
1681
3574
 
1682
- first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
3575
+ first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
1683
3576
 
1684
3577
  # first derive b'
1685
3578
 
1686
3579
  if is_v_space_pred:
1687
3580
  first_step_pred_flow = first_step_pred
1688
3581
  else:
1689
- first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
1690
- first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
3582
+ first_times = self.get_times_from_signal_level(signal_levels, noised)
3583
+
3584
+ first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
1691
3585
 
1692
3586
  # take a half step
1693
3587
 
1694
- half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
3588
+ half_step_size_align_left = align_dims_left(half_step_size, noised)
1695
3589
 
1696
- denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
3590
+ denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
1697
3591
 
1698
3592
  # get second prediction for b''
1699
3593
 
1700
3594
  signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
1701
- second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
3595
+ second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
1702
3596
 
1703
3597
  if is_v_space_pred:
1704
3598
  second_step_pred_flow = second_step_pred
1705
3599
  else:
1706
- second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
1707
- second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
3600
+ second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
3601
+ second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
1708
3602
 
1709
3603
  # pred target is sg(b' + b'') / 2
1710
3604
 
@@ -1713,7 +3607,7 @@ class DynamicsWorldModel(Module):
1713
3607
  # need to convert x-space to v-space
1714
3608
 
1715
3609
  if is_x_space:
1716
- pred = (pred - noised_latents) / (1. - first_times)
3610
+ pred = (pred - noised) / (1. - first_times)
1717
3611
  maybe_shortcut_loss_weight = (1. - first_times) ** 2
1718
3612
 
1719
3613
  # mse loss
@@ -1726,9 +3620,23 @@ class DynamicsWorldModel(Module):
1726
3620
 
1727
3621
  if exists(self.loss_weight_fn):
1728
3622
  loss_weight = self.loss_weight_fn(times)
3623
+ loss_weight = align_dims_left(loss_weight, flow_losses)
3624
+
1729
3625
  flow_losses = flow_losses * loss_weight
1730
3626
 
1731
- flow_loss = flow_losses.mean()
3627
+ # handle variable lengths if needed
3628
+
3629
+ is_var_len = exists(lens)
3630
+
3631
+ if is_var_len:
3632
+
3633
+ loss_mask = lens_to_mask(lens, time)
3634
+ loss_mask_without_last = loss_mask[:, :-1]
3635
+
3636
+ flow_loss = flow_losses[loss_mask].mean()
3637
+
3638
+ else:
3639
+ flow_loss = flow_losses.mean()
1732
3640
 
1733
3641
  # now take care of the agent token losses
1734
3642
 
@@ -1739,28 +3647,114 @@ class DynamicsWorldModel(Module):
1739
3647
  if rewards.ndim == 2: # (b t)
1740
3648
  encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
1741
3649
 
1742
- reward_pred = self.to_reward_pred(encoded_agent_tokens)
1743
- reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
3650
+ reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
3651
+
3652
+ reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
3653
+
3654
+ reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)
3655
+
3656
+ reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
3657
+
3658
+ reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
3659
+
3660
+ reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
3661
+
3662
+ if is_var_len:
3663
+ reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
3664
+ else:
3665
+ reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
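create_multi_token_prediction_targets is not shown in this hunk; a plausible (hypothetical) construction is shifted copies of the labels plus a validity mask, so that prediction head k at time t is supervised with the label at t + k:

import torch

def multi_token_targets(labels, mtp_len):                  # labels: (b, t, ...)
    b, t = labels.shape[:2]
    targets, masks = [], []

    for k in range(mtp_len):
        shifted = labels.roll(shifts = -k, dims = 1)       # label at t + k moved to position t
        valid = torch.arange(t) < (t - k)                  # last k positions have no target
        targets.append(shifted)
        masks.append(valid.expand(b, t))

    return torch.stack(targets, dim = 2), torch.stack(masks, dim = 2)   # (b, t, mtp, ...), (b, t, mtp)

targets, mask = multi_token_targets(torch.arange(5).repeat(2, 1), mtp_len = 3)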
3666
+
3667
+ # maybe autoregressive action loss
3668
+
3669
+ discrete_action_loss = self.zero
3670
+ continuous_action_loss = self.zero
3671
+
3672
+ if (
3673
+ self.num_agents == 1 and
3674
+ add_autoregressive_action_loss and
3675
+ time > 1 and
3676
+ (exists(discrete_actions) or exists(continuous_actions))
3677
+ ):
3678
+ assert self.action_embedder.has_actions
3679
+
3680
+ # handle actions having time vs time - 1 length
3681
+ # remove the first action if it is equal to time (as it would come from some agent token in the past)
3682
+
3683
+ if exists(discrete_actions) and discrete_actions.shape[1] == time:
3684
+ discrete_actions = discrete_actions[:, 1:]
3685
+
3686
+ if exists(continuous_actions) and continuous_actions.shape[1] == time:
3687
+ continuous_actions = continuous_actions[:, 1:]
3688
+
3689
+ # only for 1 agent
3690
+
3691
+ agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
3692
+ policy_embed = self.policy_head(agent_tokens[:, :-1])
3693
+
3694
+ # constitute multi token prediction targets
3695
+
3696
+ discrete_action_targets = continuous_action_targets = None
3697
+
3698
+ if exists(discrete_actions):
3699
+ discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
3700
+ discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
3701
+ discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
3702
+
3703
+ if exists(continuous_actions):
3704
+ continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
3705
+ continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
3706
+ continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
3707
+
3708
+ discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
3709
+ policy_embed,
3710
+ discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
3711
+ continuous_targets = continuous_action_targets if exists(continuous_actions) else None
3712
+ )
3713
+
3714
+ if exists(discrete_log_probs):
3715
+ discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
3716
+
3717
+ if is_var_len:
3718
+ discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
3719
+ discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
3720
+ else:
3721
+ discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
3722
+
3723
+ if exists(continuous_log_probs):
3724
+ continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
3725
+
3726
+ if is_var_len:
3727
+ continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
3728
+ continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
3729
+ else:
3730
+ continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
3731
+
3732
+ # handle loss normalization
3733
+
3734
+ losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
3735
+
3736
+ if exists(self.flow_loss_normalizer):
3737
+ flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
3738
+
3739
+ if exists(rewards) and exists(self.reward_loss_normalizer):
3740
+ reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)
1744
3741
 
1745
- # gather losses
3742
+ if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
3743
+ discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)
3744
+
3745
+ if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
3746
+ continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)
3747
+
3748
+ # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
1746
3749
 
1747
3750
  total_loss = (
1748
- flow_loss +
1749
- reward_loss * self.reward_loss_weight
3751
+ flow_loss * self.latent_flow_loss_weight +
3752
+ (reward_loss * self.reward_loss_weight).sum() +
3753
+ (discrete_action_loss * self.discrete_action_loss_weight).sum() +
3754
+ (continuous_action_loss * self.continuous_action_loss_weight).sum()
1750
3755
  )
1751
3756
 
1752
3757
  if not return_all_losses:
1753
3758
  return total_loss
1754
3759
 
1755
- return total_loss, (flow_loss, reward_loss)
1756
-
1757
- # dreamer
1758
-
1759
- class Dreamer(Module):
1760
- def __init__(
1761
- self,
1762
- video_tokenizer: VideoTokenizer,
1763
- dynamics_model: DynamicsModel,
1764
- discount_factor = 0.997
1765
- ):
1766
- super().__init__()
3760
+ return total_loss, losses