dreamer4 0.0.31__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

dreamer4/dreamer4.py CHANGED
@@ -3,16 +3,18 @@ from __future__ import annotations
3
3
  import math
4
4
  from math import ceil, log2
5
5
  from random import random
6
+ from contextlib import nullcontext
6
7
  from collections import namedtuple
7
- from functools import partial
8
- from dataclasses import dataclass
8
+ from functools import partial, wraps
9
+ from dataclasses import dataclass, asdict
9
10
 
10
11
  import torch
11
12
  import torch.nn.functional as F
12
13
  from torch.nested import nested_tensor
13
- from torch.distributions import Normal
14
+ from torch.distributions import Normal, kl
14
15
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
15
- from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
16
+ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, linspace
17
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
16
18
 
17
19
  import torchvision
18
20
  from torchvision.models import VGG16_Weights
@@ -20,11 +22,13 @@ from torchvision.models import VGG16_Weights
20
22
  from torch.optim import Optimizer
21
23
  from adam_atan2_pytorch import MuonAdamAtan2
22
24
 
23
- from x_mlps_pytorch.normed_mlp import create_mlp
24
25
  from x_mlps_pytorch.ensemble import Ensemble
26
+ from x_mlps_pytorch.normed_mlp import create_mlp
25
27
 
26
28
  from hyper_connections import get_init_and_expand_reduce_stream_functions
27
29
 
30
+ from vit_pytorch.vit_with_decorr import DecorrelationLoss
31
+
28
32
  from assoc_scan import AssocScan
29
33
 
30
34
  # ein related
@@ -35,12 +39,15 @@ from assoc_scan import AssocScan
35
39
  # d - feature dimension
36
40
  # f - frequencies (rotary)
37
41
  # l - logit / predicted bins
42
+ # y - layers of transformer
38
43
  # p - positions (3 for spacetime in this work)
39
44
  # t - time
40
45
  # na - action dimension (number of discrete and continuous actions)
41
46
  # g - groups of query heads to key heads (gqa)
42
47
  # vc - video channels
43
48
  # vh, vw - video height and width
49
+ # mtp - multi token prediction length
50
+ # v - video viewpoints
44
51
 
45
52
  import einx
46
53
  from einx import add, multiply
@@ -63,22 +70,96 @@ except ImportError:
63
70
 
64
71
  LinearNoBias = partial(Linear, bias = False)
65
72
 
66
- TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
73
+ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
74
+
75
+ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
67
76
 
68
- WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
77
+ AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
78
+
79
+ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
80
+
81
+ MaybeTensor = Tensor | None
69
82
 
70
83
  @dataclass
71
84
  class Experience:
72
85
  latents: Tensor
73
- video: Tensor | None = None
86
+ video: MaybeTensor = None
87
+ proprio: MaybeTensor = None
88
+ agent_embed: MaybeTensor = None
74
89
  rewards: Tensor | None = None
75
- actions: tuple[Tensor, Tensor] | None = None
76
- log_probs: tuple[Tensor, Tensor] | None = None
77
- values: Tensor | None = None
90
+ actions: tuple[MaybeTensor, MaybeTensor] | None = None
91
+ log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
92
+ old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
93
+ values: MaybeTensor = None
78
94
  step_size: int | None = None
95
+ lens: MaybeTensor = None
96
+ is_truncated: MaybeTensor = None
79
97
  agent_index: int = 0
80
98
  is_from_world_model: bool = True
81
99
 
100
+ def cpu(self):
101
+ return self.to(torch.device('cpu'))
102
+
103
+ def to(self, device):
104
+ experience_dict = asdict(self)
105
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
106
+ return Experience(**experience_dict)
107
+
108
+ def combine_experiences(
109
+ exps: list[Experience]
110
+ ) -> Experience:
111
+
112
+ assert len(exps) > 0
113
+
114
+ # set lens if not there
115
+
116
+ for exp in exps:
117
+ latents = exp.latents
118
+ batch, time, device = *latents.shape[:2], latents.device
119
+
120
+ if not exists(exp.lens):
121
+ exp.lens = full((batch,), time, device = device)
122
+
123
+ if not exists(exp.is_truncated):
124
+ exp.is_truncated = full((batch,), True, device = device)
125
+
126
+ # convert to dictionary
127
+
128
+ exps_dict = [asdict(exp) for exp in exps]
129
+
130
+ values, tree_specs = zip(*[tree_flatten(exp_dict) for exp_dict in exps_dict])
131
+
132
+ tree_spec = first(tree_specs)
133
+
134
+ all_field_values = list(zip(*values))
135
+
136
+ # an assert to make sure all fields are either all tensors, or a single matching value (for step size, agent index etc) - can change this later
137
+
138
+ assert all([
139
+ all([is_tensor(v) for v in field_values]) or len(set(field_values)) == 1
140
+ for field_values in all_field_values
141
+ ])
142
+
143
+ concatted = []
144
+
145
+ for field_values in all_field_values:
146
+
147
+ if is_tensor(first(field_values)):
148
+
149
+ field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))
150
+
151
+ new_field_value = cat(field_values)
152
+ else:
153
+ new_field_value = first(list(set(field_values)))
154
+
155
+ concatted.append(new_field_value)
156
+
157
+ # return experience
158
+
159
+ concat_exp_dict = tree_unflatten(concatted, tree_spec)
160
+
161
+ return Experience(**concat_exp_dict)
162
+
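Note: a minimal usage sketch of combine_experiences (the shapes below are made up for illustration). Rollouts of different lengths are right-padded along the time dimension before being concatenated on the batch dimension, and lens keeps the true lengths:

    exp_a = Experience(latents = torch.randn(2, 10, 4, 32))   # 2 episodes, 10 timesteps each
    exp_b = Experience(latents = torch.randn(3, 16, 4, 32))   # 3 episodes, 16 timesteps each

    combined = combine_experiences([exp_a, exp_b])

    assert combined.latents.shape == (5, 16, 4, 32)           # padded to the longest time length
    assert combined.lens.tolist() == [10, 10, 16, 16, 16]     # original lengths preserved per episode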
82
163
  # helpers
83
164
 
84
165
  def exists(v):
@@ -90,6 +171,9 @@ def default(v, d):
90
171
  def first(arr):
91
172
  return arr[0]
92
173
 
174
+ def xnor(x, y):
175
+ return not (x ^ y)
176
+
93
177
  def has_at_least_one(*bools):
94
178
  return sum([*map(int, bools)]) > 0
95
179
 
@@ -105,14 +189,55 @@ def sample_prob(prob):
105
189
  def is_power_two(num):
106
190
  return log2(num).is_integer()
107
191
 
192
+ def maybe(fn):
193
+ def inner(t, *args, **kwargs):
194
+ if not exists(t) or not exists(fn):
195
+ return None
196
+ return fn(t, *args, **kwargs)
197
+ return inner
198
+
108
199
  # tensor helpers
109
200
 
110
201
  def is_empty(t):
111
202
  return t.numel() == 0
112
203
 
204
+ def lens_to_mask(t, max_len = None):
205
+ if not exists(max_len):
206
+ max_len = t.amax().item()
207
+
208
+ device = t.device
209
+ seq = torch.arange(max_len, device = device)
210
+
211
+ return einx.less('j, i -> i j', seq, t)
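Note: a quick check of what lens_to_mask produces (values are illustrative):

    mask = lens_to_mask(tensor([1, 3]), max_len = 4)

    assert mask.tolist() == [
        [True, False, False, False],
        [True, True, True, False]
    ]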
212
+
213
+ def masked_mean(t, mask = None):
214
+ if not exists(mask):
215
+ return t.mean()
216
+
217
+ if not mask.any():
218
+ return t[mask].sum()
219
+
220
+ return t[mask].mean()
221
+
113
222
  def log(t, eps = 1e-20):
114
223
  return t.clamp(min = eps).log()
115
224
 
225
+ def mean_log_var_to_distr(
226
+ mean_log_var: Tensor
227
+ ) -> Normal:
228
+
229
+ mean, log_var = mean_log_var.unbind(dim = -1)
230
+ std = (0.5 * log_var).exp()
231
+ return Normal(mean, std)
232
+
233
+ def safe_stack(tensors, dim = 0):
234
+ tensors = [*filter(exists, tensors)]
235
+
236
+ if len(tensors) == 0:
237
+ return None
238
+
239
+ return stack(tensors, dim = dim)
240
+
116
241
  def safe_cat(tensors, dim):
117
242
  tensors = [*filter(exists, tensors)]
118
243
 
@@ -123,6 +248,15 @@ def safe_cat(tensors, dim):
123
248
 
124
249
  return cat(tensors, dim = dim)
125
250
 
251
+ def safe_squeeze_first(t):
252
+ if not exists(t):
253
+ return None
254
+
255
+ if t.shape[0] != 1:
256
+ return t
257
+
258
+ return rearrange(t, '1 ... -> ...')
259
+
126
260
  def gumbel_noise(t):
127
261
  noise = torch.rand_like(t)
128
262
  return -log(-log(noise))
@@ -159,6 +293,27 @@ def pad_at_dim(
159
293
  zeros = ((0, 0) * dims_from_right)
160
294
  return F.pad(t, (*zeros, *pad), value = value)
161
295
 
296
+ def pad_to_len(t, target_len, *, dim):
297
+ curr_len = t.shape[dim]
298
+
299
+ if curr_len >= target_len:
300
+ return t
301
+
302
+ return pad_at_dim(t, (0, target_len - curr_len), dim = dim)
303
+
304
+ def pad_tensors_at_dim_to_max_len(
305
+ tensors: list[Tensor],
306
+ dims: tuple[int, ...]
307
+ ):
308
+ for dim in dims:
309
+ if dim >= first(tensors).ndim:
310
+ continue
311
+
312
+ max_time = max([t.shape[dim] for t in tensors])
313
+ tensors = [pad_to_len(t, max_time, dim = dim) for t in tensors]
314
+
315
+ return tensors
316
+
162
317
  def align_dims_left(t, aligned_to):
163
318
  shape = t.shape
164
319
  num_right_dims = aligned_to.ndim - t.ndim
@@ -174,8 +329,74 @@ def l2norm(t):
174
329
  def softclamp(t, value = 50.):
175
330
  return (t / value).tanh() * value
176
331
 
332
+ def create_multi_token_prediction_targets(
333
+ t, # (b t ...)
334
+ steps_future,
335
+
336
+ ): # (b t steps ...), (b t steps) - targets and the mask, where mask is False for padding
337
+
338
+ batch, seq_len, device = *t.shape[:2], t.device
339
+
340
+ batch_arange = arange(batch, device = device)
341
+ seq_arange = arange(seq_len, device = device)
342
+ steps_arange = arange(steps_future, device = device)
343
+
344
+ indices = add('t, steps -> t steps', seq_arange, steps_arange)
345
+ mask = indices < seq_len
346
+
347
+ batch_arange = rearrange(batch_arange, 'b -> b 1 1')
348
+
349
+ indices[~mask] = 0
350
+ mask = repeat(mask, 't steps -> b t steps', b = batch)
351
+
352
+ out = t[batch_arange, indices]
353
+
354
+ return out, mask
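Note: a tiny worked example of the gather-and-mask behavior above (numbers are illustrative). Each position t collects targets for timesteps t .. t + steps - 1, and anything that would read past the end of the sequence is masked False and backed by zero padding:

    seq = torch.arange(4).reshape(1, 4)   # batch = 1, seq_len = 4
    targets, mask = create_multi_token_prediction_targets(seq, steps_future = 2)

    assert targets.tolist() == [[[0, 1], [1, 2], [2, 3], [3, 0]]]   # trailing 0 is padding
    assert mask.tolist() == [[[True, True], [True, True], [True, True], [True, False]]]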
355
+
177
356
  # loss related
178
357
 
358
+ class LossNormalizer(Module):
359
+
360
+ # the authors mentioned the need for loss normalization in the dynamics transformer
361
+
362
+ def __init__(
363
+ self,
364
+ num_losses: int,
365
+ beta = 0.95,
366
+ eps = 1e-6
367
+ ):
368
+ super().__init__()
369
+ self.register_buffer('exp_avg_sq', torch.ones(num_losses))
370
+ self.beta = beta
371
+ self.eps = eps
372
+
373
+ def forward(
374
+ self,
375
+ losses: Tensor | list[Tensor] | dict[str, Tensor],
376
+ update_ema = None
377
+ ):
378
+ exp_avg_sq, beta = self.exp_avg_sq, self.beta
379
+ update_ema = default(update_ema, self.training)
380
+
381
+ # get the rms value - as mentioned at the end of section 3 in the paper
382
+
383
+ rms = exp_avg_sq.sqrt()
384
+
385
+ if update_ema:
386
+ decay = 1. - beta
387
+
388
+ # update the ema
389
+
390
+ exp_avg_sq.lerp_(losses.detach().square(), decay)
391
+
392
+ # then normalize
393
+
394
+ assert losses.numel() == rms.numel()
395
+
396
+ normed_losses = losses / rms.clamp(min = self.eps)
397
+
398
+ return normed_losses
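Note: a minimal sketch of what this running-RMS normalization does (the numbers are illustrative). Each loss is divided by an EMA estimate of its own root-mean-square, so losses living on very different scales end up contributing comparably:

    normalizer = LossNormalizer(num_losses = 2, beta = 0.95)

    losses = torch.tensor([10.0, 0.01])   # two losses on very different scales

    for _ in range(500):                  # let the EMA of the squared losses converge
        normed = normalizer(losses, update_ema = True)

    print(normed)                         # approximately tensor([1., 1.])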
399
+
179
400
  class LPIPSLoss(Module):
180
401
  def __init__(
181
402
  self,
@@ -340,7 +561,9 @@ class ActionEmbedder(Module):
340
561
  num_continuous_actions = 0,
341
562
  continuous_norm_stats: tuple[tuple[float, float], ...] | None = None,
342
563
  can_unembed = False,
343
- unembed_dim = None
564
+ unembed_dim = None,
565
+ num_unembed_preds = 1,
566
+ squeeze_unembed_preds = True # will auto-squeeze if prediction is just 1
344
567
  ):
345
568
  super().__init__()
346
569
 
@@ -378,11 +601,14 @@ class ActionEmbedder(Module):
378
601
 
379
602
  self.can_unembed = can_unembed
380
603
 
604
+ self.num_unembed_preds = num_unembed_preds
605
+ self.squeeze_unembed_preds = squeeze_unembed_preds
606
+
381
607
  if not can_unembed:
382
608
  return
383
609
 
384
610
  unembed_dim = default(unembed_dim, dim)
385
- self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, unembed_dim) * 1e-2)
611
+ self.discrete_action_unembed = Parameter(torch.randn(total_discrete_actions, num_unembed_preds, unembed_dim) * 1e-2)
386
612
 
387
613
  discrete_action_index = arange(total_discrete_actions)
388
614
 
@@ -396,13 +622,13 @@ class ActionEmbedder(Module):
396
622
 
397
623
  self.register_buffer('discrete_action_mask', discrete_action_mask, persistent = False)
398
624
 
399
- self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
625
+ self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, num_unembed_preds, unembed_dim, 2) * 1e-2)
400
626
 
401
627
  def embed_parameters(self):
402
628
  return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
403
629
 
404
630
  def unembed_parameters(self):
405
- return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
631
+ return set([self.discrete_action_unembed, self.continuous_action_unembed])
406
632
 
407
633
  @property
408
634
  def device(self):
@@ -429,12 +655,26 @@ class ActionEmbedder(Module):
429
655
  embeds, # (... d)
430
656
  discrete_action_types = None, # (na)
431
657
  continuous_action_types = None, # (na)
432
- return_split_discrete = False
658
+ return_split_discrete = False,
659
+ pred_head_index: int | Tensor | None = None
433
660
 
434
661
  ): # (... discrete_na), (... continuous_na 2)
435
662
 
663
+ device = embeds.device
664
+
436
665
  assert self.can_unembed, 'can only unembed for predicted discrete and continuous actions if `can_unembed = True` is set on init'
437
666
 
667
+ # handle only one prediction head during inference
668
+
669
+ if exists(pred_head_index) and isinstance(pred_head_index, int):
670
+ pred_head_index = tensor(pred_head_index, device = device)
671
+
672
+ # if pred_head_index given as a solo int, just assume we want to squeeze out the prediction head dimension
673
+
674
+ squeeze_one_pred_head = exists(pred_head_index) and pred_head_index.ndim == 0
675
+
676
+ # get action types
677
+
438
678
  discrete_action_types, continuous_action_types = tuple(self.cast_action_types(t) for t in (discrete_action_types, continuous_action_types))
439
679
 
440
680
  # discrete actions
@@ -450,7 +690,13 @@ class ActionEmbedder(Module):
450
690
 
451
691
  discrete_action_unembed = discrete_action_unembed[discrete_action_mask]
452
692
 
453
- discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na d -> ... na')
693
+ if exists(pred_head_index):
694
+ discrete_action_unembed = discrete_action_unembed.index_select(1, pred_head_index)
695
+
696
+ discrete_action_logits = einsum(embeds, discrete_action_unembed, '... d, na mtp d -> mtp ... na')
697
+
698
+ if self.squeeze_unembed_preds or squeeze_one_pred_head:
699
+ discrete_action_logits = safe_squeeze_first(discrete_action_logits)
454
700
 
455
701
  # whether to split the discrete action logits by the number of actions per action type
456
702
 
@@ -471,7 +717,13 @@ class ActionEmbedder(Module):
471
717
  if exists(continuous_action_types):
472
718
  continuous_action_unembed = continuous_action_unembed[continuous_action_types]
473
719
 
474
- continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na d two -> ... na two')
720
+ if exists(pred_head_index):
721
+ continuous_action_unembed = continuous_action_unembed.index_select(1, pred_head_index)
722
+
723
+ continuous_action_mean_log_var = einsum(embeds, continuous_action_unembed, '... d, na mtp d two -> mtp ... na two')
724
+
725
+ if self.squeeze_unembed_preds or squeeze_one_pred_head:
726
+ continuous_action_mean_log_var = safe_squeeze_first(continuous_action_mean_log_var)
475
727
 
476
728
  return discrete_action_logits, continuous_action_mean_log_var
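Note: a shape sketch of the multi-token prediction unembedding (sizes are illustrative). During training every mtp head is unembedded, giving a leading prediction-step dimension; at inference pred_head_index = 0 (or a squeeze of the first head) keeps only the prediction for the immediate next step:

    embeds  = torch.randn(2, 6, 512)    # (batch, time, dim)
    unembed = torch.randn(10, 8, 512)   # (num discrete actions, mtp heads, dim)

    logits = einsum(embeds, unembed, '... d, na mtp d -> mtp ... na')
    assert logits.shape == (8, 2, 6, 10)                         # one set of logits per future step

    assert safe_squeeze_first(logits[:1]).shape == (2, 6, 10)    # keep just the first head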
477
729
 
@@ -480,9 +732,14 @@ class ActionEmbedder(Module):
480
732
  embed,
481
733
  discrete_temperature = 1.,
482
734
  continuous_temperature = 1.,
735
+ inverse_norm_continuous = None,
736
+ pred_head_index: int | Tensor | None = None,
737
+ squeeze = True,
483
738
  **kwargs
484
739
  ):
485
- discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, **kwargs)
740
+ inverse_norm_continuous = default(inverse_norm_continuous, self.continuous_need_norm)
741
+
742
+ discrete_logits, continuous_mean_log_var = self.unembed(embed, return_split_discrete = True, pred_head_index = pred_head_index, **kwargs)
486
743
 
487
744
  sampled_discrete = sampled_continuous = None
488
745
 
@@ -500,6 +757,12 @@ class ActionEmbedder(Module):
500
757
 
501
758
  sampled_continuous = mean + std * torch.randn_like(mean) * continuous_temperature
502
759
 
760
+ # maybe inverse norm
761
+
762
+ if inverse_norm_continuous:
763
+ norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
764
+ sampled_continuous = (sampled_continuous * norm_std) + norm_mean
765
+
503
766
  return sampled_discrete, sampled_continuous
504
767
 
505
768
  def log_probs(
@@ -509,12 +772,13 @@ class ActionEmbedder(Module):
509
772
  continuous_targets = None, # (... na)
510
773
  discrete_action_types = None, # (na)
511
774
  continuous_action_types = None, # (na)
775
+ pred_head_index: int | Tensor | None = None,
512
776
  parallel_discrete_calc = None,
513
777
  return_entropies = False
514
778
  ):
515
779
  parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
516
780
 
517
- discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
781
+ discrete_action_logits, continuous_action_mean_log_var = self.unembed(embeds, pred_head_index = pred_head_index, discrete_action_types = discrete_action_types, continuous_action_types = continuous_action_types, return_split_discrete = True)
518
782
 
519
783
  # discrete
520
784
 
@@ -600,10 +864,7 @@ class ActionEmbedder(Module):
600
864
  continuous_entropies = None
601
865
 
602
866
  if exists(continuous_targets):
603
- mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
604
- std = (0.5 * log_var).exp()
605
-
606
- distr = Normal(mean, std)
867
+ distr = mean_log_var_to_distr(continuous_action_mean_log_var)
607
868
  continuous_log_probs = distr.log_prob(continuous_targets)
608
869
 
609
870
  if return_entropies:
@@ -618,6 +879,64 @@ class ActionEmbedder(Module):
618
879
 
619
880
  return log_probs, entropies
620
881
 
882
+ def kl_div(
883
+ self,
884
+ src: tuple[MaybeTensor, MaybeTensor],
885
+ tgt: tuple[MaybeTensor, MaybeTensor],
886
+ reduce_across_num_actions = True
887
+ ) -> tuple[MaybeTensor, MaybeTensor]:
888
+
889
+ src_discrete, src_continuous = src
890
+ tgt_discrete, tgt_continuous = tgt
891
+
892
+ discrete_kl_div = None
893
+
894
+ # split discrete if it is not already (multiple discrete actions)
895
+
896
+ if exists(src_discrete):
897
+
898
+ discrete_split = self.num_discrete_actions.tolist()
899
+
900
+ if is_tensor(src_discrete):
901
+ src_discrete = src_discrete.split(discrete_split, dim = -1)
902
+
903
+ if is_tensor(tgt_discrete):
904
+ tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
905
+
906
+ discrete_kl_divs = []
907
+
908
+ for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
909
+
910
+ src_log_probs = src_logit.log_softmax(dim = -1)
911
+ tgt_prob = tgt_logit.softmax(dim = -1)
912
+
913
+ one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
914
+
915
+ discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
916
+
917
+ discrete_kl_div = stack(discrete_kl_divs, dim = -1)
918
+
919
+ # calculate kl divergence for continuous
920
+
921
+ continuous_kl_div = None
922
+
923
+ if exists(src_continuous):
924
+ src_normal = mean_log_var_to_distr(src_continuous)
925
+ tgt_normal = mean_log_var_to_distr(tgt_continuous)
926
+
927
+ continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
928
+
929
+ # maybe reduce
930
+
931
+ if reduce_across_num_actions:
932
+ if exists(discrete_kl_div):
933
+ discrete_kl_div = discrete_kl_div.sum(dim = -1)
934
+
935
+ if exists(continuous_kl_div):
936
+ continuous_kl_div = continuous_kl_div.sum(dim = -1)
937
+
938
+ return discrete_kl_div, continuous_kl_div
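Note: for the continuous actions, kl.kl_divergence between the two Normals built from the (mean, log variance) heads is the usual closed-form Gaussian KL. A small standalone check with illustrative numbers:

    # KL( N(0, 1) || N(1, 1) ) = log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2 = 0.5
    src = stack((tensor(0.), tensor(0.)), dim = -1)   # (mean, log_var) for N(0, 1)
    tgt = stack((tensor(1.), tensor(0.)), dim = -1)   # (mean, log_var) for N(1, 1)

    kl_val = kl.kl_divergence(mean_log_var_to_distr(src), mean_log_var_to_distr(tgt))
    torch.testing.assert_close(kl_val, tensor(0.5))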
939
+
621
940
  def forward(
622
941
  self,
623
942
  *,
@@ -726,11 +1045,12 @@ class Rotary1D(Module):
726
1045
 
727
1046
  def forward(
728
1047
  self,
729
- seq_len
1048
+ seq_len,
1049
+ offset = 0
730
1050
  ):
731
1051
  device, dtype = self.inv_freq.device, self.inv_freq.dtype
732
1052
 
733
- t = torch.arange(seq_len, device = device).type(dtype)
1053
+ t = torch.arange(seq_len, device = device).type(dtype) + offset
734
1054
  freqs = einsum(t, self.inv_freq, 'i, j -> i j')
735
1055
 
736
1056
  return cat((freqs, freqs), dim = -1)
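Note: the new offset argument exists so that, when stepping with a time KV cache, only the newest timestep needs fresh rotations. A quick equivalence check (64 is just an example head dimension):

    rotary = Rotary1D(64)

    full_rots = rotary(8)               # rotations for timesteps 0 .. 7
    last_rot  = rotary(1, offset = 7)   # rotation for only the newest timestep

    torch.testing.assert_close(last_rot, full_rots[-1:])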
@@ -740,7 +1060,18 @@ def apply_rotations(
740
1060
  rotations, # (h n d) | (n d)
741
1061
  t # (b h n d)
742
1062
  ):
743
- heads, dtype = t.shape[1], t.dtype
1063
+
1064
+ heads, seq_len, dtype = *t.shape[1:3], t.dtype
1065
+
1066
+ rotations_seq_len = rotations.shape[-2]
1067
+
1068
+ # handle kv caching with rotations
1069
+
1070
+ if rotations_seq_len > seq_len:
1071
+ rotations = rotations[-seq_len:]
1072
+
1073
+ # precision
1074
+
744
1075
  t = t.float()
745
1076
 
746
1077
  # handle gqa for rotary
@@ -877,10 +1208,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
877
1208
 
878
1209
  def block_mask_special_tokens_right(
879
1210
  seq_len,
880
- num_tokens
1211
+ num_tokens,
1212
+ special_attend_only_itself = False
881
1213
  ):
882
1214
  def inner(b, h, q, k):
883
- return special_token_mask(q, k, seq_len, num_tokens)
1215
+ return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
884
1216
  return inner
885
1217
 
886
1218
  def compose_mask(mask1, mask2):
@@ -957,6 +1289,10 @@ class Attention(Module):
957
1289
  query_heads = None,
958
1290
  heads = 8,
959
1291
  pre_rmsnorm = True,
1292
+ gate_values = True,
1293
+ rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
1294
+ rmsnorm_key = True,
1295
+ value_residual = True
960
1296
  ):
961
1297
  super().__init__()
962
1298
  self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -979,12 +1315,33 @@ class Attention(Module):
979
1315
  self.to_v = LinearNoBias(dim, dim_kv_inner)
980
1316
  self.to_out = LinearNoBias(dim_q_inner, dim)
981
1317
 
1318
+ # alphafold gating per head, for attending to nothing
1319
+
1320
+ self.to_gates = None
1321
+
1322
+ if gate_values:
1323
+ self.to_gates = Sequential(
1324
+ LinearNoBias(dim, query_heads),
1325
+ Rearrange('b n h -> b h n 1'),
1326
+ nn.Sigmoid()
1327
+ )
1328
+
982
1329
  # stability related
983
1330
 
984
- self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
985
- self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
1331
+ self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
1332
+ self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
1333
+
1334
+ # value residual
1335
+
1336
+ self.to_learned_value_residual_mix = nn.Sequential(
1337
+ nn.Linear(dim, heads),
1338
+ Rearrange('b n h -> b h n 1'),
1339
+ nn.Sigmoid()
1340
+ ) if value_residual else None
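Note: on the learned value-residual mix used in the forward pass below - v.lerp(residual_values, mix) with a sigmoid-gated mix blends this layer's values with values projected from the first layer's input, per head and position. A small sketch with illustrative shapes:

    v       = torch.randn(1, 8, 16, 64)   # this layer's values (b h n d)
    v_first = torch.randn(1, 8, 16, 64)   # values projected from the first layer
    mix     = torch.rand(1, 8, 16, 1)     # sigmoid output in (0, 1)

    mixed = v.lerp(v_first, mix)          # (1 - mix) * v + mix * v_first
    torch.testing.assert_close(mixed, (1 - mix) * v + mix * v_first)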
986
1341
 
987
1342
  def muon_parameters(self):
1343
+ # omit the queries and keys for now given what we learned from kimi 2 paper
1344
+
988
1345
  return [
989
1346
  *self.to_v.parameters(),
990
1347
  *self.to_out.parameters(),
@@ -994,8 +1351,9 @@ class Attention(Module):
994
1351
  self,
995
1352
  tokens, # (b n d)
996
1353
  kv_cache = None,
997
- return_kv_cache = False,
1354
+ return_intermediates = False,
998
1355
  rotary_pos_emb = None,
1356
+ residual_values = None, # (b n h d)
999
1357
  attend_fn: Callable | None = None
1000
1358
  ):
1001
1359
  tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1008,11 +1366,28 @@ class Attention(Module):
1008
1366
 
1009
1367
  q, k, v = map(self.split_heads, (q, k, v))
1010
1368
 
1369
+ # handle maybe value residual
1370
+
1371
+ if exists(residual_values):
1372
+ residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
1373
+
1374
+ assert exists(self.to_learned_value_residual_mix)
1375
+
1376
+ learned_mix = self.to_learned_value_residual_mix(tokens)
1377
+
1378
+ v = v.lerp(residual_values, learned_mix)
1379
+
1011
1380
  # qk rmsnorm
1012
1381
 
1013
1382
  q = self.q_heads_rmsnorm(q)
1014
1383
  k = self.k_heads_rmsnorm(k)
1015
1384
 
1385
+ # rotary
1386
+
1387
+ if exists(rotary_pos_emb):
1388
+ q = apply_rotations(rotary_pos_emb, q)
1389
+ k = apply_rotations(rotary_pos_emb, k)
1390
+
1016
1391
  # caching
1017
1392
 
1018
1393
  if exists(kv_cache):
@@ -1020,18 +1395,18 @@ class Attention(Module):
1020
1395
  k = cat((ck, k), dim = -2)
1021
1396
  v = cat((cv, v), dim = -2)
1022
1397
 
1023
- # rotary
1024
-
1025
- if exists(rotary_pos_emb):
1026
- q = apply_rotations(rotary_pos_emb, q)
1027
- k = apply_rotations(rotary_pos_emb, k)
1028
-
1029
1398
  # attention
1030
1399
 
1031
1400
  attend_fn = default(attend_fn, naive_attend)
1032
1401
 
1033
1402
  out = attend_fn(q, k, v)
1034
1403
 
1404
+ # gate values
1405
+
1406
+ if exists(self.to_gates):
1407
+ gates = self.to_gates(tokens)
1408
+ out = out * gates
1409
+
1035
1410
  # merge heads
1036
1411
 
1037
1412
  out = self.merge_heads(out)
@@ -1042,10 +1417,10 @@ class Attention(Module):
1042
1417
 
1043
1418
  out = inverse_packed_batch(out)
1044
1419
 
1045
- if not return_kv_cache:
1420
+ if not return_intermediates:
1046
1421
  return out
1047
1422
 
1048
- return out, stack((k, v))
1423
+ return out, AttentionIntermediates(stack((k, v)), tokens)
1049
1424
 
1050
1425
  # feedforward
1051
1426
 
@@ -1085,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
1085
1460
  self,
1086
1461
  dim,
1087
1462
  depth,
1463
+ attn_heads = 8,
1088
1464
  attn_dim_head = 64,
1089
1465
  attn_softclamp_value = 50.,
1090
1466
  time_block_every = 4,
@@ -1093,9 +1469,12 @@ class AxialSpaceTimeTransformer(Module):
1093
1469
  num_residual_streams = 1,
1094
1470
  num_special_spatial_tokens = 1,
1095
1471
  special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
1096
- final_norm = True
1472
+ final_norm = True,
1473
+ value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
1474
+ rnn_time = False
1097
1475
  ):
1098
1476
  super().__init__()
1477
+ assert depth >= time_block_every, f'depth must be at least {time_block_every}'
1099
1478
 
1100
1479
  # hyper connections
1101
1480
 
@@ -1113,6 +1492,24 @@ class AxialSpaceTimeTransformer(Module):
1113
1492
 
1114
1493
  self.time_rotary = Rotary1D(attn_dim_head)
1115
1494
 
1495
+ # project initial for value residuals
1496
+
1497
+ self.value_residual = value_residual
1498
+
1499
+ if value_residual:
1500
+ dim_inner = attn_dim_head * attn_heads
1501
+
1502
+ self.to_value_residual = nn.Sequential(
1503
+ nn.RMSNorm(dim),
1504
+ nn.Linear(dim, dim_inner, bias = False),
1505
+ Rearrange('... (h d) -> ... h d', h = attn_heads)
1506
+ )
1507
+
1508
+ # a gru layer across time
1509
+
1510
+ self.rnn_time = rnn_time
1511
+ rnn_layers = []
1512
+
1116
1513
  # transformer
1117
1514
 
1118
1515
  layers = []
@@ -1124,17 +1521,24 @@ class AxialSpaceTimeTransformer(Module):
1124
1521
  is_time_block = divisible_by(layer_index, time_block_every)
1125
1522
  is_time.append(is_time_block)
1126
1523
 
1127
- rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
1128
- rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
1524
+ rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
1525
+ rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
1129
1526
 
1130
1527
  layers.append(ModuleList([
1131
1528
  rearrange_to_attend,
1132
1529
  rearrange_from_attend,
1133
- hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
1530
+ hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
1134
1531
  hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
1135
1532
  ]))
1136
1533
 
1534
+ rnn_layers.append(ModuleList([
1535
+ nn.RMSNorm(dim),
1536
+ nn.GRU(dim, dim, batch_first = True)
1537
+ ]) if is_time_block and rnn_time else None)
1538
+
1137
1539
  self.layers = ModuleList(layers)
1540
+ self.rnn_layers = ModuleList(rnn_layers)
1541
+
1138
1542
  self.is_time = is_time
1139
1543
 
1140
1544
  # final norm
@@ -1145,17 +1549,31 @@ class AxialSpaceTimeTransformer(Module):
1145
1549
 
1146
1550
  self.num_special_spatial_tokens = num_special_spatial_tokens
1147
1551
 
1552
+ def muon_parameters(self):
1553
+ muon_params = []
1554
+
1555
+ for m in self.modules():
1556
+ if isinstance(m, (Attention, SwiGLUFeedforward)):
1557
+ muon_params.extend(m.muon_parameters())
1558
+
1559
+ return muon_params
1560
+
1148
1561
  def forward(
1149
1562
  self,
1150
- tokens # (b t s d)
1151
- ):
1563
+ tokens, # (b t s d)
1564
+ kv_cache: Tensor | None = None, # (y 2 b h t d)
1565
+ return_intermediates = False
1566
+
1567
+ ): # (b t s d) | (y 2 b h t d)
1568
+
1152
1569
  batch, time, space_seq_len, _, device = *tokens.shape, tokens.device
1153
1570
 
1154
1571
  assert tokens.ndim == 4
1155
1572
 
1156
1573
  # attend functions for space and time
1157
1574
 
1158
- use_flex = exists(flex_attention) and tokens.is_cuda
1575
+ has_kv_cache = exists(kv_cache)
1576
+ use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
1159
1577
 
1160
1578
  attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
1161
1579
 
@@ -1163,37 +1581,120 @@ class AxialSpaceTimeTransformer(Module):
1163
1581
 
1164
1582
  time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
1165
1583
 
1584
+ # prepare cache
1585
+
1586
+ time_attn_kv_caches = []
1587
+
1588
+ if has_kv_cache:
1589
+ past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
1590
+
1591
+ rotary_seq_len = 1
1592
+ rotary_pos_offset = past_tokens.shape[1]
1593
+ else:
1594
+ rotary_seq_len = time
1595
+ rotary_pos_offset = 0
1596
+
1597
+ kv_cache = default(kv_cache, (None,))
1598
+
1599
+ iter_kv_cache = iter(kv_cache)
1600
+
1166
1601
  # rotary
1167
1602
 
1168
- rotary_pos_emb = self.time_rotary(time)
1603
+ rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
1604
+
1605
+ # value residual
1606
+
1607
+ residual_values = None
1608
+
1609
+ if self.value_residual:
1610
+ residual_values = self.to_value_residual(tokens)
1611
+
1612
+ # normed attention inputs
1613
+
1614
+ normed_time_attn_inputs = []
1615
+ normed_space_attn_inputs = []
1169
1616
 
1170
1617
  # attention
1171
1618
 
1172
1619
  tokens = self.expand_streams(tokens)
1173
1620
 
1174
- for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
1621
+ for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
1175
1622
 
1176
1623
  tokens = pre_attn_rearrange(tokens)
1177
1624
 
1625
+ # maybe rnn for time
1626
+
1627
+ if layer_is_time and exists(maybe_rnn_modules):
1628
+ rnn_prenorm, rnn = maybe_rnn_modules
1629
+
1630
+ rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
1631
+
1632
+ rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
1633
+
1634
+ rnn_out = inverse_pack_time(rnn_out)
1635
+
1636
+ tokens = rnn_out + tokens
1637
+
1178
1638
  # when is a axial time attention block, should be causal
1179
1639
 
1180
1640
  attend_fn = time_attend if layer_is_time else space_attend
1181
1641
 
1182
1642
  layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
1183
1643
 
1644
+ # maybe past kv cache
1645
+
1646
+ maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
1647
+
1648
+ # residual values
1649
+
1650
+ layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
1651
+
1184
1652
  # attention layer
1185
1653
 
1186
- tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
1654
+ tokens, attn_intermediates = attn(
1655
+ tokens,
1656
+ rotary_pos_emb = layer_rotary_pos_emb,
1657
+ attend_fn = attend_fn,
1658
+ kv_cache = maybe_kv_cache,
1659
+ residual_values = layer_residual_values,
1660
+ return_intermediates = True
1661
+ )
1187
1662
 
1188
1663
  tokens = post_attn_rearrange(tokens)
1189
1664
 
1190
1665
  # feedforward layer
1191
1666
 
1192
- tokens = ff(tokens) + tokens
1667
+ tokens = ff(tokens)
1668
+
1669
+ # save kv cache if is time layer
1670
+
1671
+ if layer_is_time:
1672
+ time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
1673
+
1674
+ # save time attention inputs for decorr
1675
+
1676
+ space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
1677
+
1678
+ space_or_time_inputs.append(attn_intermediates.normed_inputs)
1193
1679
 
1194
1680
  tokens = self.reduce_streams(tokens)
1195
1681
 
1196
- return self.final_norm(tokens)
1682
+ out = self.final_norm(tokens)
1683
+
1684
+ if has_kv_cache:
1685
+ # just concat the past tokens back on for now, todo - clean up the logic
1686
+ out = cat((past_tokens, out), dim = 1)
1687
+
1688
+ if not return_intermediates:
1689
+ return out
1690
+
1691
+ intermediates = TransformerIntermediates(
1692
+ stack(time_attn_kv_caches),
1693
+ safe_stack(normed_time_attn_inputs),
1694
+ safe_stack(normed_space_attn_inputs)
1695
+ )
1696
+
1697
+ return out, intermediates
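Note: a sketch of how the returned intermediates support incremental decoding over time (the configuration and shapes are illustrative). The per-time-layer KV cache from one call is passed back on the next call together with the full token history, so the causal time-attention layers only process the newest timestep:

    transformer = AxialSpaceTimeTransformer(dim = 512, depth = 4)   # illustrative config

    tokens = torch.randn(1, 4, 16, 512)                             # (b t s d)
    out, inter = transformer(tokens, return_intermediates = True)

    # append one new timestep and reuse the cached keys / values for the time layers
    tokens = cat((tokens, torch.randn(1, 1, 16, 512)), dim = 1)
    out, inter = transformer(tokens, kv_cache = inter.next_kv_cache, return_intermediates = True)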
1197
1698
 
1198
1699
  # video tokenizer
1199
1700
 
@@ -1219,12 +1720,15 @@ class VideoTokenizer(Module):
1219
1720
  per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
1220
1721
  lpips_loss_network: Module | None = None,
1221
1722
  lpips_loss_weight = 0.2,
1723
+ encoder_add_decor_aux_loss = False,
1724
+ decor_auxx_loss_weight = 0.1,
1725
+ decorr_sample_frac = 0.25,
1222
1726
  nd_rotary_kwargs: dict = dict(
1223
1727
  rope_min_freq = 1.,
1224
1728
  rope_max_freq = 10000.,
1225
1729
  rope_p_zero_freqs = 0.
1226
1730
  ),
1227
- num_residual_streams = 1
1731
+ num_residual_streams = 1,
1228
1732
  ):
1229
1733
  super().__init__()
1230
1734
 
@@ -1305,6 +1809,7 @@ class VideoTokenizer(Module):
1305
1809
  time_block_every = time_block_every,
1306
1810
  num_special_spatial_tokens = num_latent_tokens,
1307
1811
  num_residual_streams = num_residual_streams,
1812
+ special_attend_only_itself = True,
1308
1813
  final_norm = True
1309
1814
  )
1310
1815
 
@@ -1318,10 +1823,24 @@ class VideoTokenizer(Module):
1318
1823
  if self.has_lpips_loss:
1319
1824
  self.lpips = LPIPSLoss(lpips_loss_network)
1320
1825
 
1826
+ # decorr aux loss
1827
+ # https://arxiv.org/abs/2510.14657
1828
+
1829
+ self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
1830
+ self.decorr_aux_loss_weight = decor_auxx_loss_weight
1831
+
1832
+ self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
1833
+
1321
1834
  @property
1322
1835
  def device(self):
1323
1836
  return self.zero.device
1324
1837
 
1838
+ def muon_parameters(self):
1839
+ return [
1840
+ *self.encoder_transformer.muon_parameters(),
1841
+ *self.decoder_transformer.muon_parameters()
1842
+ ]
1843
+
1325
1844
  @torch.no_grad()
1326
1845
  def tokenize(
1327
1846
  self,
@@ -1425,7 +1944,7 @@ class VideoTokenizer(Module):
1425
1944
 
1426
1945
  # encoder attention
1427
1946
 
1428
- tokens = self.encoder_transformer(tokens)
1947
+ tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
1429
1948
 
1430
1949
  # latent bottleneck
1431
1950
 
@@ -1447,19 +1966,30 @@ class VideoTokenizer(Module):
1447
1966
  if self.has_lpips_loss:
1448
1967
  lpips_loss = self.lpips(video, recon_video)
1449
1968
 
1969
+ time_decorr_loss = space_decorr_loss = self.zero
1970
+
1971
+ if self.encoder_add_decor_aux_loss:
1972
+ if exists(time_attn_normed_inputs):
1973
+ time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
1974
+
1975
+ if exists(space_attn_normed_inputs):
1976
+ space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
1977
+
1450
1978
  # losses
1451
1979
 
1452
1980
  total_loss = (
1453
1981
  recon_loss +
1454
- lpips_loss * self.lpips_loss_weight
1982
+ lpips_loss * self.lpips_loss_weight +
1983
+ time_decorr_loss * self.decorr_aux_loss_weight +
1984
+ space_decorr_loss * self.decorr_aux_loss_weight
1455
1985
  )
1456
1986
 
1457
1987
  if not return_all_losses:
1458
1988
  return total_loss
1459
1989
 
1460
- losses = (recon_loss, lpips_loss)
1990
+ losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
1461
1991
 
1462
- return total_loss, TokenizerLosses(losses)
1992
+ return total_loss, TokenizerLosses(*losses)
1463
1993
 
1464
1994
  # dynamics model, axial space-time transformer
1465
1995
 
@@ -1475,13 +2005,15 @@ class DynamicsWorldModel(Module):
1475
2005
  num_latent_tokens = None,
1476
2006
  num_agents = 1,
1477
2007
  num_tasks = 0,
2008
+ num_video_views = 1,
2009
+ dim_proprio = None,
1478
2010
  reward_encoder_kwargs: dict = dict(),
1479
2011
  depth = 4,
1480
2012
  pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
1481
2013
  time_block_every = 4, # every 4th block is time
1482
- attn_kwargs: dict = dict(
1483
- heads = 8,
1484
- ),
2014
+ attn_kwargs: dict = dict(),
2015
+ transformer_kwargs: dict = dict(),
2016
+ attn_heads = 8,
1485
2017
  attn_dim_head = 64,
1486
2018
  attn_softclamp_value = 50.,
1487
2019
  ff_kwargs: dict = dict(),
@@ -1493,15 +2025,25 @@ class DynamicsWorldModel(Module):
1493
2025
  num_discrete_actions: int | tuple[int, ...] = 0,
1494
2026
  num_continuous_actions = 0,
1495
2027
  continuous_norm_stats = None,
1496
- reward_loss_weight = 0.1,
2028
+ multi_token_pred_len = 8,
1497
2029
  value_head_mlp_depth = 3,
1498
2030
  policy_head_mlp_depth = 3,
1499
- behavior_clone_weight = 0.1,
2031
+ latent_flow_loss_weight = 1.,
2032
+ reward_loss_weight: float | list[float] = 1.,
2033
+ discrete_action_loss_weight: float | list[float] = 1.,
2034
+ continuous_action_loss_weight: float | list[float] = 1.,
1500
2035
  num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
1501
2036
  num_residual_streams = 1,
2037
+ keep_reward_ema_stats = False,
2038
+ reward_ema_decay = 0.998,
2039
+ reward_quantile_filter = (0.05, 0.95),
1502
2040
  gae_discount_factor = 0.997,
1503
2041
  gae_lambda = 0.95,
1504
2042
  ppo_eps_clip = 0.2,
2043
+ pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
2044
+ pmpo_reverse_kl = True,
2045
+ pmpo_kl_div_loss_weight = .3,
2046
+ normalize_advantages = None,
1505
2047
  value_clip = 0.4,
1506
2048
  policy_entropy_weight = .01,
1507
2049
  gae_use_accelerated = False
@@ -1535,7 +2077,7 @@ class DynamicsWorldModel(Module):
1535
2077
  )
1536
2078
 
1537
2079
  self.to_latent_pred = Sequential(
1538
- Reduce('b t n s d -> b t n d', 'mean'),
2080
+ Reduce('b t v n s d -> b t v n d', 'mean'),
1539
2081
  RMSNorm(dim),
1540
2082
  LinearNoBias(dim, dim_latent)
1541
2083
  )
@@ -1545,15 +2087,38 @@ class DynamicsWorldModel(Module):
1545
2087
  latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
1546
2088
 
1547
2089
  self.latents_to_spatial_tokens = Sequential(
1548
- Rearrange('b t n d -> b t (n d)'),
2090
+ Rearrange('... n d -> ... (n d)'),
1549
2091
  Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
1550
- Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
2092
+ Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
1551
2093
  )
1552
2094
 
1553
2095
  self.to_latent_pred = Sequential(
1554
2096
  RMSNorm(dim),
1555
2097
  LinearNoBias(dim, dim_latent * latent_tokens_to_space),
1556
- Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
2098
+ Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
2099
+ )
2100
+
2101
+ # number of video views, for robotics, which could have third person + wrist camera at least
2102
+
2103
+ assert num_video_views >= 1
2104
+ self.video_has_multi_view = num_video_views > 1
2105
+
2106
+ self.num_video_views = num_video_views
2107
+
2108
+ if self.video_has_multi_view:
2109
+ self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
2110
+
2111
+ # proprioception
2112
+
2113
+ self.has_proprio = exists(dim_proprio)
2114
+ self.dim_proprio = dim_proprio
2115
+
2116
+ if self.has_proprio:
2117
+ self.to_proprio_token = nn.Linear(dim_proprio, dim)
2118
+
2119
+ self.to_proprio_pred = Sequential(
2120
+ RMSNorm(dim),
2121
+ nn.Linear(dim, dim_proprio)
1557
2122
  )
1558
2123
 
1559
2124
  # register tokens
@@ -1597,6 +2162,7 @@ class DynamicsWorldModel(Module):
1597
2162
  # learned set of latent genes
1598
2163
 
1599
2164
  self.agent_has_genes = num_latent_genes > 0
2165
+ self.num_latent_genes = num_latent_genes
1600
2166
  self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
1601
2167
 
1602
2168
  # policy head
@@ -1616,10 +2182,14 @@ class DynamicsWorldModel(Module):
1616
2182
  num_continuous_actions = num_continuous_actions,
1617
2183
  continuous_norm_stats = continuous_norm_stats,
1618
2184
  can_unembed = True,
1619
- unembed_dim = dim * 4
2185
+ unembed_dim = dim * 4,
2186
+ num_unembed_preds = multi_token_pred_len,
2187
+ squeeze_unembed_preds = False
1620
2188
  )
1621
2189
 
1622
- self.behavior_clone_weight = behavior_clone_weight
2190
+ # multi token prediction length
2191
+
2192
+ self.multi_token_pred_len = multi_token_pred_len
1623
2193
 
1624
2194
  # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
1625
2195
 
@@ -1632,12 +2202,15 @@ class DynamicsWorldModel(Module):
1632
2202
  learned_embedding = add_reward_embed_to_agent_token
1633
2203
  )
1634
2204
 
1635
- self.to_reward_pred = Sequential(
2205
+ to_reward_pred = Sequential(
1636
2206
  RMSNorm(dim),
1637
2207
  LinearNoBias(dim, self.reward_encoder.num_bins)
1638
2208
  )
1639
2209
 
1640
- self.reward_loss_weight = reward_loss_weight
2210
+ self.to_reward_pred = Ensemble(
2211
+ to_reward_pred,
2212
+ multi_token_pred_len
2213
+ )
1641
2214
 
1642
2215
  # value head
1643
2216
 
@@ -1653,13 +2226,16 @@ class DynamicsWorldModel(Module):
1653
2226
  self.transformer = AxialSpaceTimeTransformer(
1654
2227
  dim = dim,
1655
2228
  depth = depth,
2229
+ attn_heads = attn_heads,
1656
2230
  attn_dim_head = attn_dim_head,
1657
2231
  attn_softclamp_value = attn_softclamp_value,
1658
2232
  attn_kwargs = attn_kwargs,
1659
2233
  ff_kwargs = ff_kwargs,
1660
2234
  num_residual_streams = num_residual_streams,
1661
2235
  num_special_spatial_tokens = num_agents,
1662
- final_norm = False
2236
+ time_block_every = time_block_every,
2237
+ final_norm = False,
2238
+ **transformer_kwargs
1663
2239
  )
1664
2240
 
1665
2241
  # ppo related
@@ -1670,9 +2246,40 @@ class DynamicsWorldModel(Module):
1670
2246
 
1671
2247
  self.ppo_eps_clip = ppo_eps_clip
1672
2248
  self.value_clip = value_clip
1673
- self.policy_entropy_weight = value_clip
2249
+ self.policy_entropy_weight = policy_entropy_weight
2250
+
2251
+ # pmpo related
2252
+
2253
+ self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
2254
+ self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
2255
+ self.pmpo_reverse_kl = pmpo_reverse_kl
2256
+
2257
+ # rewards related
2258
+
2259
+ self.keep_reward_ema_stats = keep_reward_ema_stats
2260
+ self.reward_ema_decay = reward_ema_decay
2261
+
2262
+ self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
2263
+
2264
+ self.register_buffer('ema_returns_mean', tensor(0.))
2265
+ self.register_buffer('ema_returns_var', tensor(1.))
2266
+
2267
+ # loss related
2268
+
2269
+ self.flow_loss_normalizer = LossNormalizer(1)
2270
+ self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
2271
+ self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
2272
+ self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
2273
+
2274
+ self.latent_flow_loss_weight = latent_flow_loss_weight
2275
+
2276
+ self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
2277
+ self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
2278
+ self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
1674
2279
 
1675
- # zero
2280
+ assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
2281
+ assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
2282
+ assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
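Note: these weights accept either a single float applied to every prediction head, or one weight per multi-token prediction step - for example a hypothetical decay over the horizon:

    multi_token_pred_len = 8
    reward_loss_weight = [0.5 ** i for i in range(multi_token_pred_len)]   # hypothetical per-step decay

    assert tensor(reward_loss_weight).numel() == multi_token_pred_len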
1676
2283
 
1677
2284
  self.register_buffer('zero', tensor(0.), persistent = False)
1678
2285
 
@@ -1680,17 +2287,19 @@ class DynamicsWorldModel(Module):
1680
2287
  def device(self):
1681
2288
  return self.zero.device
1682
2289
 
1683
- def get_times_from_signal_level(
1684
- self,
1685
- signal_levels,
1686
- align_dims_left_to = None
1687
- ):
1688
- times = signal_levels.float() / self.max_steps
2290
+ # types of parameters
1689
2291
 
1690
- if not exists(align_dims_left_to):
1691
- return times
2292
+ def muon_parameters(self):
2293
+ return self.transformer.muon_parameters()
1692
2294
 
1693
- return align_dims_left(times, align_dims_left_to)
2295
+ def policy_head_parameters(self):
2296
+ return [
2297
+ *self.policy_head.parameters(),
2298
+ *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
2299
+ ]
2300
+
2301
+ def value_head_parameters(self):
2302
+ return self.value_head.parameters()
1694
2303
 
1695
2304
  def parameter(self):
1696
2305
  params = super().parameters()
@@ -1700,53 +2309,402 @@ class DynamicsWorldModel(Module):
1700
2309
 
1701
2310
  return list(set(params) - set(self.video_tokenizer.parameters()))
1702
2311
 
1703
- def learn_policy_from_generations(
2312
+ # helpers for shortcut flow matching
2313
+
2314
+ def get_times_from_signal_level(
2315
+ self,
2316
+ signal_levels,
2317
+ align_dims_left_to = None
2318
+ ):
2319
+ times = signal_levels.float() / self.max_steps
2320
+
2321
+ if not exists(align_dims_left_to):
2322
+ return times
2323
+
2324
+ return align_dims_left(times, align_dims_left_to)
2325
+
2326
+ # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
2327
+
2328
+ @torch.no_grad()
2329
+ def evolve_(
2330
+ self,
2331
+ fitness,
2332
+ select_frac = 0.5,
2333
+ tournament_frac = 0.5
2334
+ ):
2335
+ assert fitness.numel() == self.num_latent_genes
2336
+
2337
+ pop = self.latent_genes
2338
+
2339
+ pop_size = self.num_latent_genes
2340
+ num_selected = ceil(pop_size * select_frac)
2341
+ num_children = pop_size - num_selected
2342
+
2343
+ dim_gene = pop.shape[-1]
2344
+
2345
+ # natural selection just a sort and slice
2346
+
2347
+ selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
2348
+ selected = pop[selected_indices]
2349
+
2350
+ # use tournament - one tournament per child
2351
+
2352
+ tournament_size = max(2, ceil(num_selected * tournament_frac))
2353
+
2354
+ tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]
2355
+
2356
+ parent_ids = selected_fitness[tournaments].topk(2, dim = -1).indices # get top 2 winners as parents
2357
+
2358
+ parents = selected[parent_ids]
2359
+
2360
+ # crossover by random interpolation from parent1 to parent2
2361
+
2362
+ random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()
2363
+
2364
+ parent1, parent2 = parents.unbind(dim = 1)
2365
+ children = parent1.lerp(parent2, random_uniform_mix)
2366
+
2367
+ # store next population
2368
+
2369
+ next_pop = cat((selected, children))
2370
+
2371
+ self.latent_genes.copy_(next_pop)
2372
+
2373
+ # interacting with env for experience
2374
+
2375
+ @torch.no_grad()
2376
+ def interact_with_env(
2377
+ self,
2378
+ env,
2379
+ seed = None,
2380
+ agent_index = 0,
2381
+ step_size = 4,
2382
+ max_timesteps = 16,
2383
+ env_is_vectorized = False,
2384
+ use_time_kv_cache = True,
2385
+ store_agent_embed = True,
2386
+ store_old_action_unembeds = True,
2387
+ ):
2388
+ assert exists(self.video_tokenizer)
2389
+
2390
+ init_frame = env.reset()
2391
+
2392
+ # frame to video
2393
+
2394
+ if env_is_vectorized:
2395
+ video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
2396
+ else:
2397
+ video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
2398
+
2399
+ batch, device = video.shape[0], video.device
2400
+
2401
+ # accumulate
2402
+
2403
+ rewards = None
2404
+ discrete_actions = None
2405
+ continuous_actions = None
2406
+ discrete_log_probs = None
2407
+ continuous_log_probs = None
2408
+ values = None
2409
+ latents = None
2410
+
2411
+ acc_agent_embed = None
2412
+ acc_policy_embed = None
2413
+
2414
+ # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
2415
+
2416
+ is_terminated = full((batch,), False, device = device)
2417
+ is_truncated = full((batch,), False, device = device)
2418
+
2419
+ episode_lens = full((batch,), 0, device = device)
2420
+
2421
+ # maybe time kv cache
2422
+
2423
+ time_kv_cache = None
2424
+
2425
+ step_index = 0
2426
+
2427
+ while not is_terminated.all():
2428
+ step_index += 1
2429
+
2430
+ latents = self.video_tokenizer(video, return_latents = True)
2431
+
2432
+ _, (agent_embed, next_time_kv_cache) = self.forward(
2433
+ latents = latents,
2434
+ signal_levels = self.max_steps - 1,
2435
+ step_sizes = step_size,
2436
+ rewards = rewards,
2437
+ discrete_actions = discrete_actions,
2438
+ continuous_actions = continuous_actions,
2439
+ time_kv_cache = time_kv_cache,
2440
+ latent_is_noised = True,
2441
+ return_pred_only = True,
2442
+ return_intermediates = True
2443
+ )
2444
+
2445
+ # time kv cache
2446
+
2447
+ if use_time_kv_cache:
2448
+ time_kv_cache = next_time_kv_cache
2449
+
2450
+ # get one agent
2451
+
2452
+ one_agent_embed = agent_embed[..., -1:, agent_index, :]
2453
+
2454
+ # values
2455
+
2456
+ value_bins = self.value_head(one_agent_embed)
2457
+ value = self.reward_encoder.bins_to_scalar_value(value_bins)
2458
+
2459
+ values = safe_cat((values, value), dim = 1)
2460
+
2461
+ # policy embed
2462
+
2463
+ policy_embed = self.policy_head(one_agent_embed)
2464
+
2465
+ if store_old_action_unembeds:
2466
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
2467
+
2468
+ # sample actions
2469
+
2470
+ sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
2471
+
2472
+ discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
2473
+ continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
2474
+
2475
+ # get the log prob and values for policy optimization
2476
+
2477
+ one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
2478
+ policy_embed,
2479
+ pred_head_index = 0,
2480
+ discrete_targets = sampled_discrete_actions,
2481
+ continuous_targets = sampled_continuous_actions,
2482
+ )
2483
+
2484
+ discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
2485
+ continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
2486
+
2487
+ # pass the sampled action to the environment and get back next state and reward
2488
+
2489
+ env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
2490
+
2491
+ if len(env_step_out) == 2:
2492
+ next_frame, reward = env_step_out
2493
+ terminated = full((batch,), False)
2494
+ truncated = full((batch,), False)
2495
+
2496
+ elif len(env_step_out) == 3:
2497
+ next_frame, reward, terminated = env_step_out
2498
+ truncated = full((batch,), False)
2499
+
2500
+ elif len(env_step_out) == 4:
2501
+ next_frame, reward, terminated, truncated = env_step_out
2502
+
2503
+ elif len(env_step_out) == 5:
2504
+ next_frame, reward, terminated, truncated, info = env_step_out
2505
+
2506
+ # update episode lens
2507
+
2508
+ episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
2509
+
2510
+ # update `is_terminated`
2511
+
2512
+ # (1) - environment says it is terminated
2513
+ # (2) - previous step is truncated (this step is for bootstrap value)
2514
+
2515
+ is_terminated |= (terminated | is_truncated)
2516
+
2517
+ # update `is_truncated`
2518
+
2519
+ if step_index <= max_timesteps:
2520
+ is_truncated |= truncated
2521
+
2522
+ if step_index == max_timesteps:
2523
+ # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
2524
+
2525
+ is_truncated |= ~is_terminated
2526
+
2527
+ # batch and time dimension
2528
+
2529
+ if env_is_vectorized:
2530
+ next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
2531
+ reward = rearrange(reward, 'b -> b 1')
2532
+ else:
2533
+ next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
2534
+ reward = rearrange(reward, ' -> 1 1')
2535
+
2536
+ # concat
2537
+
2538
+ video = cat((video, next_frame), dim = 2)
2539
+ rewards = safe_cat((rewards, reward), dim = 1)
2540
+
2541
+ acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
2542
+
2543
+ # package up one experience for learning
2544
+
2545
+ batch, device = latents.shape[0], latents.device
2546
+
2547
+ one_experience = Experience(
2548
+ latents = latents,
2549
+ video = video[:, :, :-1],
2550
+ rewards = rewards,
2551
+ actions = (discrete_actions, continuous_actions),
2552
+ log_probs = (discrete_log_probs, continuous_log_probs),
2553
+ values = values,
2554
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
2555
+ agent_embed = acc_agent_embed if store_agent_embed else None,
2556
+ step_size = step_size,
2557
+ agent_index = agent_index,
2558
+ is_truncated = is_truncated,
2559
+ lens = episode_lens,
2560
+ is_from_world_model = False
2561
+ )
2562
+
2563
+ return one_experience
2564
+
2565
+ # ppo
2566
+
2567
+ def learn_from_experience(
1704
2568
  self,
1705
- generation: Experience,
2569
+ experience: Experience,
1706
2570
  policy_optim: Optimizer | None = None,
1707
- value_optim: Optimizer | None = None
2571
+ value_optim: Optimizer | None = None,
2572
+ only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
2573
+ use_pmpo = True,
2574
+ normalize_advantages = None,
2575
+ eps = 1e-6
1708
2576
  ):
1709
- latents = generation.latents
1710
- actions = generation.actions
1711
- old_log_probs = generation.log_probs
1712
- old_values = generation.values
1713
- rewards = generation.rewards
2577
+ assert isinstance(experience, Experience)
2578
+
2579
+ experience = experience.to(self.device)
2580
+
2581
+ latents = experience.latents
2582
+ actions = experience.actions
2583
+ old_log_probs = experience.log_probs
2584
+ old_values = experience.values
2585
+ rewards = experience.rewards
2586
+ agent_embeds = experience.agent_embed
2587
+ old_action_unembeds = experience.old_action_unembeds
2588
+
2589
+ step_size = experience.step_size
2590
+ agent_index = experience.agent_index
2591
+
2592
+ assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
2593
+
2594
+ batch, time = latents.shape[0], latents.shape[1]
1714
2595
 
1715
- step_size = generation.step_size
1716
- agent_index = generation.agent_index
2596
+ # calculate returns
1717
2597
 
1718
- assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
2598
+ # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
2599
+
2600
+ if not exists(experience.is_truncated):
2601
+ experience.is_truncated = full((batch,), True, device = latents.device)
2602
+
2603
+ if exists(experience.lens):
2604
+ mask_for_gae = lens_to_mask(experience.lens, time)
2605
+
2606
+ rewards = rewards.masked_fill(~mask_for_gae, 0.)
2607
+ old_values = old_values.masked_fill(~mask_for_gae, 0.)
2608
+
2609
+ # calculate returns
1719
2610
 
1720
2611
  returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
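# [editor's note] a minimal reference sketch of what a GAE-style `calc_gae` computes, for
# readers unfamiliar with the helper; the (batch, time) layout matches `rewards` and
# `old_values` above, but the gamma / lambda defaults are illustrative and the repo's
# accelerated associative-scan version differs in implementation.

import torch

def calc_gae_sketch(rewards, values, gamma = 0.997, lam = 0.95):
    # rewards, values: (batch, time) -> lambda-returns: (batch, time)
    batch, time = rewards.shape

    advantages = torch.zeros_like(rewards)
    next_advantage = torch.zeros_like(rewards[:, 0])
    next_value = torch.zeros_like(rewards[:, 0])  # zero bootstrap past the final step

    for t in reversed(range(time)):
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[:, t] = next_advantage
        next_value = values[:, t]

    return advantages + values  # the returns regressed by the value head below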
1721
2612
 
1722
- # apparently they just use the sign of the advantage
2613
+ # handle variable lengths
2614
+
2615
+ max_time = latents.shape[1]
2616
+ is_var_len = exists(experience.lens)
2617
+
2618
+ mask = None
2619
+
2620
+ if is_var_len:
2621
+ learnable_lens = experience.lens - experience.is_truncated.long() # if truncated, remove the last step, as it is a bootstrapped value
2622
+ mask = lens_to_mask(learnable_lens, max_time)
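# [editor's note] `lens_to_mask` is a small helper defined earlier in this file; a plausible
# equivalent is the one-liner below, producing a (batch, max_time) boolean mask that is True
# for positions before each sequence's length.

import torch

def lens_to_mask_sketch(lens, max_time):
    return torch.arange(max_time, device = lens.device) < lens[:, None]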
2623
+
2624
+ # determine whether to finetune entire transformer or just learn the heads
2625
+
2626
+ world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
2627
+
2628
+ # maybe keep track of return statistics and normalize returns and values before calculating the advantage, as done in dreamer v3
2629
+
2630
+ if self.keep_reward_ema_stats:
2631
+ ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
2632
+
2633
+ decay = 1. - self.reward_ema_decay
2634
+
2635
+ # quantile filter
2636
+
2637
+ lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
2638
+ returns_for_stats = returns.clamp(lo, hi)
2639
+
2640
+ # mean, var - todo - handle distributed
2641
+
2642
+ returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
2643
+
2644
+ # ema
2645
+
2646
+ ema_returns_mean.lerp_(returns_mean, decay)
2647
+ ema_returns_var.lerp_(returns_var, decay)
2648
+
2649
+ # normalize
2650
+
2651
+ ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
2652
+
2653
+ normed_returns = (returns - ema_returns_mean) / ema_returns_std
2654
+ normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
2655
+
2656
+ advantage = normed_returns - normed_old_values
2657
+ else:
2658
+ advantage = returns - old_values
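# [editor's note] compact sketch of the dreamer-v3 style return normalization branch above:
# returns are clamped to a quantile range, an EMA of their mean / variance is tracked, and
# returns and values are scaled by the same statistics before taking the advantage. names
# and defaults here are illustrative, not the module's attributes.

import torch

@torch.no_grad()
def normalize_with_return_stats(returns, values, ema_mean, ema_var, ema_decay = 0.99, quantiles = (0.05, 0.95)):
    lo, hi = torch.quantile(returns, torch.tensor(quantiles, device = returns.device)).tolist()
    clamped = returns.clamp(lo, hi)

    ema_mean.lerp_(clamped.mean(), 1. - ema_decay)
    ema_var.lerp_(clamped.var(), 1. - ema_decay)

    std = ema_var.clamp(min = 1e-5).sqrt()
    return (returns - ema_mean) / std, (values - ema_mean) / std

# usage: keep ema_mean = torch.zeros(()) and ema_var = torch.ones(()) as persistent buffers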
2659
+
2660
+ # if using pmpo, do not normalize advantages, but can be overridden
2661
+
2662
+ normalize_advantages = default(normalize_advantages, not use_pmpo)
2663
+
2664
+ if normalize_advantages:
2665
+ advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
2666
+
1723
2667
  # https://arxiv.org/abs/2410.04166v1
1724
2668
 
1725
- advantage = (returns - old_values).sign()
2669
+ if use_pmpo:
2670
+ pos_advantage_mask = advantage >= 0.
2671
+ neg_advantage_mask = ~pos_advantage_mask
1726
2672
 
1727
2673
  # replay for the action logits and values
2674
+ # but only do so if fine-tuning the entire world model for RL, or if the agent embeds were not stored during the rollout
1728
2675
 
1729
2676
  discrete_actions, continuous_actions = actions
1730
2677
 
1731
- _, agent_embed = self.forward(
1732
- latents = latents,
1733
- signal_levels = self.max_steps - 1,
1734
- step_sizes = step_size,
1735
- rewards = rewards,
1736
- discrete_actions = discrete_actions,
1737
- continuous_actions = continuous_actions,
1738
- latent_is_noised = True,
1739
- return_pred_only = True,
1740
- return_agent_tokens = True
1741
- )
2678
+ if (
2679
+ not only_learn_policy_value_heads or
2680
+ not exists(agent_embeds)
2681
+ ):
2682
+
2683
+ with world_model_forward_context():
2684
+ _, (agent_embeds, _) = self.forward(
2685
+ latents = latents,
2686
+ signal_levels = self.max_steps - 1,
2687
+ step_sizes = step_size,
2688
+ rewards = rewards,
2689
+ discrete_actions = discrete_actions,
2690
+ continuous_actions = continuous_actions,
2691
+ latent_is_noised = True,
2692
+ return_pred_only = True,
2693
+ return_intermediates = True
2694
+ )
2695
+
2696
+ agent_embeds = agent_embeds[..., agent_index, :]
1742
2697
 
1743
- agent_embed = agent_embed[..., agent_index, :]
2698
+ # maybe detach agent embed
2699
+
2700
+ if only_learn_policy_value_heads:
2701
+ agent_embeds = agent_embeds.detach()
1744
2702
 
1745
2703
  # ppo
1746
2704
 
1747
- policy_embed = self.policy_head(agent_embed)
2705
+ policy_embed = self.policy_head(agent_embeds)
1748
2706
 
1749
- log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
2707
+ log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
1750
2708
 
1751
2709
  # concat discrete and continuous actions into one for optimizing
1752
2710
 
@@ -1754,29 +2712,77 @@ class DynamicsWorldModel(Module):
1754
2712
  log_probs = safe_cat(log_probs, dim = -1)
1755
2713
  entropies = safe_cat(entropies, dim = -1)
1756
2714
 
1757
- ratio = (log_probs - old_log_probs).exp()
2715
+ advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
1758
2716
 
1759
- clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
2717
+ if use_pmpo:
2718
+ # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
2719
+ # seems to be weighted across batch and time, iiuc
2720
+ # eq (10) in https://arxiv.org/html/2410.04166v1
1760
2721
 
1761
- advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
2722
+ if exists(mask):
2723
+ pos_advantage_mask &= mask
2724
+ neg_advantage_mask &= mask
2725
+
2726
+ α = self.pmpo_pos_to_neg_weight
2727
+
2728
+ pos = masked_mean(log_probs, pos_advantage_mask)
2729
+ neg = -masked_mean(log_probs, neg_advantage_mask)
2730
+
2731
+ policy_loss = -(α * pos + (1. - α) * neg)
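# [editor's note] standalone sketch of the PMPO objective being formed above (eq. (10) of
# https://arxiv.org/abs/2410.04166): only the sign of the advantage is used, and the positive
# and negative sets are balanced by a fixed alpha rather than by advantage magnitude. shapes
# and the local masked-mean helper are assumptions for illustration.

import torch

def pmpo_policy_loss(log_probs, advantages, alpha = 0.5, mask = None):
    # log_probs: (batch, time, num_actions), advantages: (batch, time), mask: (batch, time) or None
    pos_mask = advantages >= 0.
    neg_mask = ~pos_mask

    if mask is not None:
        pos_mask, neg_mask = pos_mask & mask, neg_mask & mask

    def mean_where(values, where):
        where = where[..., None].expand_as(values)
        return values[where].mean()

    pos = mean_where(log_probs, pos_mask)   # push up log probs of positively advantaged actions
    neg = -mean_where(log_probs, neg_mask)  # push down the rest

    return -(alpha * pos + (1. - alpha) * neg)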
2732
+
2733
+ # take care of kl
2734
+
2735
+ if self.pmpo_kl_div_loss_weight > 0.:
2736
+
2737
+ new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
2738
+
2739
+ kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
2740
+
2741
+ # mentioned that the "reverse direction for the prior KL" was used
2742
+ # make optional, as observed instability in toy task
2743
+
2744
+ if self.pmpo_reverse_kl:
2745
+ kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
2746
+
2747
+ discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)
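# [editor's note] the forward vs. reverse choice toggled above, written out for a pair of
# discrete action logits; this is a generic torch illustration, not the signature of
# `action_embedder.kl_div`.

import torch

def categorical_kl(logits_p, logits_q):
    log_p = logits_p.log_softmax(dim = -1)
    log_q = logits_q.log_softmax(dim = -1)
    return (log_p.exp() * (log_p - log_q)).sum(dim = -1)  # KL(p || q) per position

# with the new policy first, categorical_kl(new_logits, old_logits) penalizes mass the new
# policy places where the old policy had little; swapping the arguments gives the
# "reverse direction for the prior KL" mentioned above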
1762
2748
 
1763
- # clipped surrogate loss
2749
+ # accumulate discrete and continuous kl div
1764
2750
 
1765
- policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
1766
- policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
2751
+ kl_div_loss = 0.
1767
2752
 
1768
- policy_loss = policy_loss.mean()
2753
+ if exists(discrete_kl_div):
2754
+ kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
2755
+
2756
+ if exists(continuous_kl_div):
2757
+ kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
2758
+
2759
+ policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
2760
+
2761
+ else:
2762
+ # ppo clipped surrogate loss
2763
+
2764
+ ratio = (log_probs - old_log_probs).exp()
2765
+ clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
2766
+
2767
+ policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
2768
+ policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
2769
+
2770
+ policy_loss = masked_mean(policy_loss, mask)
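# [editor's note] reference form of the clipped surrogate objective used in the branch above,
# with plain arguments in place of the class attributes; shapes follow the
# (batch, time, num_actions) log probs and broadcastable advantage used here.

import torch

def ppo_clipped_loss(log_probs, old_log_probs, advantage, eps_clip = 0.2):
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)

    surrogate = torch.min(ratio * advantage, clipped_ratio * advantage)
    return -surrogate.sum(dim = -1).mean()  # sum over actions, mean over batch and time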
1769
2771
 
1770
2772
  # handle entropy loss for naive exploration bonus
1771
2773
 
1772
- entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
2774
+ entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
2775
+
2776
+ entropy_loss = masked_mean(entropy_loss, mask)
2777
+
2778
+ # total policy loss
1773
2779
 
1774
2780
  total_policy_loss = (
1775
2781
  policy_loss +
1776
2782
  entropy_loss * self.policy_entropy_weight
1777
2783
  )
1778
2784
 
1779
- # maye take policy optimizer step
2785
+ # maybe take policy optimizer step
1780
2786
 
1781
2787
  if exists(policy_optim):
1782
2788
  total_policy_loss.backward()
@@ -1786,7 +2792,7 @@ class DynamicsWorldModel(Module):
1786
2792
 
1787
2793
  # value loss
1788
2794
 
1789
- value_bins = self.value_head(agent_embed)
2795
+ value_bins = self.value_head(agent_embeds)
1790
2796
  values = self.reward_encoder.bins_to_scalar_value(value_bins)
1791
2797
 
1792
2798
  clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
@@ -1794,10 +2800,19 @@ class DynamicsWorldModel(Module):
1794
2800
 
1795
2801
  return_bins = self.reward_encoder(returns)
1796
2802
 
2803
+ value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
2804
+
1797
2805
  value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
1798
2806
  value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
1799
2807
 
1800
- value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
2808
+ value_loss = torch.maximum(value_loss_1, value_loss_2)
2809
+
2810
+ # maybe variable length
2811
+
2812
+ if is_var_len:
2813
+ value_loss = value_loss[mask].mean()
2814
+ else:
2815
+ value_loss = value_loss.mean()
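# [editor's note] the value (and reward) heads regress a distribution over bins and are
# trained with cross entropy against a "two-hot" target, as in the lines above. a generic
# encode / decode pair for a fixed bin grid is sketched below; the actual `reward_encoder`
# may additionally transform the scalar (e.g. symlog) and space its bins differently.

import torch

def two_hot_encode(value, bin_values):
    # value: (...,), bin_values: (num_bins,) sorted ascending -> (..., num_bins)
    value = value.clamp(bin_values[0], bin_values[-1])

    idx_hi = torch.searchsorted(bin_values, value.contiguous(), right = True).clamp(1, bin_values.numel() - 1)
    idx_lo = idx_hi - 1

    lo, hi = bin_values[idx_lo], bin_values[idx_hi]
    weight_hi = (value - lo) / (hi - lo).clamp(min = 1e-8)

    target = torch.zeros(*value.shape, bin_values.numel(), device = value.device)
    target.scatter_(-1, idx_lo.unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    target.scatter_(-1, idx_hi.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return target

def bins_to_scalar(logits, bin_values):
    return (logits.softmax(dim = -1) * bin_values).sum(dim = -1)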
1801
2816
 
1802
2817
  # maybe take value optimizer step
1803
2818
 
@@ -1816,22 +2831,49 @@ class DynamicsWorldModel(Module):
1816
2831
  num_steps = 4,
1817
2832
  batch_size = 1,
1818
2833
  agent_index = 0,
2834
+ tasks: int | Tensor | None = None,
2835
+ latent_gene_ids = None,
1819
2836
  image_height = None,
1820
2837
  image_width = None,
1821
2838
  return_decoded_video = None,
1822
2839
  context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
2840
+ time_kv_cache: Tensor | None = None,
2841
+ use_time_kv_cache = True,
1823
2842
  return_rewards_per_frame = False,
1824
2843
  return_agent_actions = False,
1825
- return_log_probs_and_values = False
2844
+ return_log_probs_and_values = False,
2845
+ return_for_policy_optimization = False,
2846
+ return_time_kv_cache = False,
2847
+ store_agent_embed = True,
2848
+ store_old_action_unembeds = True
1826
2849
 
1827
2850
  ): # (b t n d) | (b c t h w)
1828
2851
 
2852
+ # handy flag for returning generations for rl
2853
+
2854
+ if return_for_policy_optimization:
2855
+ return_agent_actions |= True
2856
+ return_log_probs_and_values |= True
2857
+ return_rewards_per_frame |= True
2858
+
2859
+ # more variables
2860
+
2861
+ has_proprio = self.has_proprio
1829
2862
  was_training = self.training
1830
2863
  self.eval()
1831
2864
 
2865
+ # validation
2866
+
1832
2867
  assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
1833
2868
  assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
1834
2869
 
2870
+ if isinstance(tasks, int):
2871
+ tasks = full((batch_size,), tasks, device = self.device)
2872
+
2873
+ assert not exists(tasks) or tasks.shape[0] == batch_size
2874
+
2875
+ # get state latent shape
2876
+
1835
2877
  latent_shape = self.latent_shape
1836
2878
 
1837
2879
  # derive step size
@@ -1841,9 +2883,16 @@ class DynamicsWorldModel(Module):
1841
2883
  # denoising
1842
2884
  # teacher forcing to start with
1843
2885
 
1844
- latents = empty((batch_size, 0, *latent_shape), device = self.device)
2886
+ latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
2887
+
2888
+ past_latents_context_noise = latents.clone()
1845
2889
 
1846
- past_context_noise = latents.clone()
2890
+ # maybe internal state
2891
+
2892
+ if has_proprio:
2893
+ proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
2894
+
2895
+ past_proprio_context_noise = proprio.clone()
1847
2896
 
1848
2897
  # maybe return actions
1849
2898
 
@@ -1858,8 +2907,17 @@ class DynamicsWorldModel(Module):
1858
2907
  decoded_continuous_log_probs = None
1859
2908
  decoded_values = None
1860
2909
 
2910
+ # maybe store agent embed
2911
+
2912
+ acc_agent_embed = None
2913
+
2914
+ # maybe store old actions for kl
2915
+
2916
+ acc_policy_embed = None
2917
+
1861
2918
  # maybe return rewards
1862
2919
 
2920
+ decoded_rewards = None
1863
2921
  if return_rewards_per_frame:
1864
2922
  decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
1865
2923
 
@@ -1868,55 +2926,131 @@ class DynamicsWorldModel(Module):
1868
2926
  while latents.shape[1] < time_steps:
1869
2927
 
1870
2928
  curr_time_steps = latents.shape[1]
1871
- noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
1872
2929
 
1873
- for step in range(num_steps):
2930
+ # determine whether to take an extra step if
2931
+ # (1) using time kv cache
2932
+ # (2) decoding anything off agent embedding (rewards, actions, etc)
2933
+
2934
+ take_extra_step = (
2935
+ use_time_kv_cache or
2936
+ return_rewards_per_frame or
2937
+ store_agent_embed or
2938
+ return_agent_actions
2939
+ )
2940
+
2941
+ # prepare noised latent / proprio inputs
2942
+
2943
+ noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
2944
+
2945
+ noised_proprio = None
2946
+
2947
+ if has_proprio:
2948
+ noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
2949
+
2950
+ # denoising steps
2951
+
2952
+ for step in range(num_steps + int(take_extra_step)):
2953
+
2954
+ is_last_step = (step + 1) == num_steps
2955
+
1874
2956
  signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
2957
+
2958
+ # noising past latent context
1875
2959
 
1876
- noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
2960
+ noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
1877
2961
 
1878
- noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
2962
+ noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
2963
+
2964
+ # handle proprio
2965
+
2966
+ noised_proprio_with_context = None
2967
+
2968
+ if has_proprio:
2969
+ noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
2970
+ noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
2971
+
2972
+ # proper signal levels
1879
2973
 
1880
2974
  signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
1881
2975
 
1882
- pred, agent_embed = self.forward(
2976
+ pred, (agent_embed, next_time_kv_cache) = self.forward(
1883
2977
  latents = noised_latent_with_context,
1884
2978
  signal_levels = signal_levels_with_context,
1885
2979
  step_sizes = step_size,
1886
2980
  rewards = decoded_rewards,
2981
+ tasks = tasks,
2982
+ latent_gene_ids = latent_gene_ids,
1887
2983
  discrete_actions = decoded_discrete_actions,
1888
2984
  continuous_actions = decoded_continuous_actions,
2985
+ proprio = noised_proprio_with_context,
2986
+ time_kv_cache = time_kv_cache,
1889
2987
  latent_is_noised = True,
2988
+ latent_has_view_dim = True,
1890
2989
  return_pred_only = True,
1891
- return_agent_tokens = True
2990
+ return_intermediates = True,
1892
2991
  )
1893
2992
 
1894
- _, pred = unpack(pred, pack_context_shape, 'b * n d')
2993
+ if use_time_kv_cache and is_last_step:
2994
+ time_kv_cache = next_time_kv_cache
2995
+
2996
+ # early break if taking an extra step to get the agent embedding off the cleaned latents for decoding
2997
+
2998
+ if take_extra_step and is_last_step:
2999
+ break
3000
+
3001
+ # maybe proprio
3002
+
3003
+ if has_proprio:
3004
+ pred, pred_proprio = pred
3005
+
3006
+ # unpack pred
3007
+
3008
+ _, pred = unpack(pred, pack_context_shape, 'b * v n d')
3009
+
3010
+ if has_proprio:
3011
+ _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
1895
3012
 
1896
3013
  # derive flow, based on whether in x-space or not
1897
3014
 
1898
- if self.pred_orig_latent:
1899
- times = self.get_times_from_signal_level(signal_levels, noised_latent)
1900
- flow = (pred - noised_latent) / (1. - times)
1901
- else:
1902
- flow = pred
3015
+ def denoise_step(pred, noised, signal_levels):
3016
+ if self.pred_orig_latent:
3017
+ times = self.get_times_from_signal_level(signal_levels)
3018
+ aligned_times = align_dims_left(times, noised)
3019
+
3020
+ flow = (pred - noised) / (1. - aligned_times)
3021
+ else:
3022
+ flow = pred
3023
+
3024
+ return flow * (step_size / self.max_steps)
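# [editor's note] quick numeric check of the x-space conversion inside `denoise_step`: with
# x_t = noise.lerp(data, t) = (1 - t) * noise + t * data, the quantity (data - x_t) / (1 - t)
# equals the constant velocity data - noise, which is exactly the flow integrated by the
# Euler update that follows.

import torch

noise, data = torch.randn(4), torch.randn(4)
t = 0.25
x_t = noise.lerp(data, t)
assert torch.allclose((data - x_t) / (1. - t), data - noise, atol = 1e-6)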
1903
3025
 
1904
3026
  # denoise
1905
3027
 
1906
- noised_latent += flow * (step_size / self.max_steps)
3028
+ noised_latent += denoise_step(pred, noised_latent, signal_levels)
3029
+
3030
+ if has_proprio:
3031
+ noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
1907
3032
 
1908
3033
  denoised_latent = noised_latent # it is now denoised
1909
3034
 
3035
+ if has_proprio:
3036
+ denoised_proprio = noised_proprio
3037
+
1910
3038
  # take care of the rewards by predicting on the agent token embedding on the last denoising step
1911
3039
 
1912
3040
  if return_rewards_per_frame:
1913
3041
  one_agent_embed = agent_embed[:, -1:, agent_index]
1914
3042
 
1915
- reward_logits = self.to_reward_pred(one_agent_embed)
3043
+ reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
1916
3044
  pred_reward = self.reward_encoder.bins_to_scalar_value(reward_logits, normalize = True)
1917
3045
 
1918
3046
  decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
1919
3047
 
3048
+ # maybe store agent embed
3049
+
3050
+ if store_agent_embed:
3051
+ one_agent_embed = agent_embed[:, -1:, agent_index]
3052
+ acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
3053
+
1920
3054
  # decode the agent actions if needed
1921
3055
 
1922
3056
  if return_agent_actions:
@@ -1926,7 +3060,14 @@ class DynamicsWorldModel(Module):
1926
3060
 
1927
3061
  policy_embed = self.policy_head(one_agent_embed)
1928
3062
 
1929
- sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
3063
+ # maybe store old actions
3064
+
3065
+ if store_old_action_unembeds:
3066
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
3067
+
3068
+ # sample actions
3069
+
3070
+ sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
1930
3071
 
1931
3072
  decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
1932
3073
  decoded_continuous_actions = safe_cat((decoded_continuous_actions, sampled_continuous_actions), dim = 1)
@@ -1934,6 +3075,7 @@ class DynamicsWorldModel(Module):
1934
3075
  if return_log_probs_and_values:
1935
3076
  discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
1936
3077
  policy_embed,
3078
+ pred_head_index = 0,
1937
3079
  discrete_targets = sampled_discrete_actions,
1938
3080
  continuous_targets = sampled_continuous_actions,
1939
3081
  )
@@ -1952,7 +3094,14 @@ class DynamicsWorldModel(Module):
1952
3094
 
1953
3095
  # add new fixed context noise for the temporal consistency
1954
3096
 
1955
- past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
3097
+ past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
3098
+
3099
+ # handle proprio
3100
+
3101
+ if has_proprio:
3102
+ proprio = cat((proprio, denoised_proprio), dim = 1)
3103
+
3104
+ past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
1956
3105
 
1957
3106
  # restore state
1958
3107
 
@@ -1966,24 +3115,50 @@ class DynamicsWorldModel(Module):
1966
3115
  video = None
1967
3116
 
1968
3117
  if return_decoded_video:
3118
+
3119
+ latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
3120
+ latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
3121
+
1969
3122
  video = self.video_tokenizer.decode(
1970
- latents,
3123
+ latents_for_video,
1971
3124
  height = image_height,
1972
3125
  width = image_width
1973
3126
  )
1974
3127
 
3128
+ video = unpack_view(video, '* t c vh vw')
3129
+
3130
+ # remove the lone view dimension
3131
+
3132
+ if not self.video_has_multi_view:
3133
+ latents = rearrange(latents, 'b t 1 ... -> b t ...')
3134
+
3135
+ if exists(video):
3136
+ video = rearrange(video, 'b 1 ... -> b ...')
3137
+
1975
3138
  # only return video or latent if not requesting anything else, for first stage training
1976
3139
 
1977
- if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
1978
- return video if return_decoded_video else latents
3140
+ if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
3141
+ out = video if return_decoded_video else latents
3142
+
3143
+ if not return_time_kv_cache:
3144
+ return out
3145
+
3146
+ return out, time_kv_cache
1979
3147
 
1980
3148
  # returning agent actions, rewards, and log probs + values for policy optimization
1981
3149
 
3150
+ batch, device = latents.shape[0], latents.device
3151
+ experience_lens = full((batch,), time_steps, device = device)
3152
+
1982
3153
  gen = Experience(
1983
3154
  latents = latents,
1984
3155
  video = video,
3156
+ proprio = proprio if has_proprio else None,
3157
+ agent_embed = acc_agent_embed if store_agent_embed else None,
3158
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
1985
3159
  step_size = step_size,
1986
3160
  agent_index = agent_index,
3161
+ lens = experience_lens,
1987
3162
  is_from_world_model = True
1988
3163
  )
1989
3164
 
@@ -1998,13 +3173,17 @@ class DynamicsWorldModel(Module):
1998
3173
 
1999
3174
  gen.values = decoded_values
2000
3175
 
2001
- return gen
3176
+ if not return_time_kv_cache:
3177
+ return gen
3178
+
3179
+ return gen, time_kv_cache
2002
3180
 
2003
3181
  def forward(
2004
3182
  self,
2005
3183
  *,
2006
- video = None, # (b c t vh vw)
2007
- latents = None, # (b t n d) | (b t d)
3184
+ video = None, # (b v? c t vh vw)
3185
+ latents = None, # (b t v? n d) | (b t v? d)
3186
+ lens = None, # (b)
2008
3187
  signal_levels = None, # () | (b) | (b t)
2009
3188
  step_sizes = None, # () | (b)
2010
3189
  step_sizes_log2 = None, # () | (b)
@@ -2015,25 +3194,46 @@ class DynamicsWorldModel(Module):
2015
3194
  continuous_actions = None, # (b t na) | (b t-1 na)
2016
3195
  discrete_action_types = None, # (na)
2017
3196
  continuous_action_types = None, # (na)
3197
+ proprio = None, # (b t dp)
3198
+ time_kv_cache = None,
2018
3199
  return_pred_only = False,
2019
3200
  latent_is_noised = False,
2020
3201
  return_all_losses = False,
2021
- return_agent_tokens = False,
2022
- add_autoregressive_action_loss = False
3202
+ return_intermediates = False,
3203
+ add_autoregressive_action_loss = True,
3204
+ update_loss_ema = None,
3205
+ latent_has_view_dim = False
2023
3206
  ):
2024
3207
  # handle video or latents
2025
3208
 
2026
3209
  assert exists(video) ^ exists(latents)
2027
3210
 
3211
+ # standardize view dimension
3212
+
3213
+ if not self.video_has_multi_view:
3214
+ if exists(video):
3215
+ video = rearrange(video, 'b ... -> b 1 ...')
3216
+
3217
+ if exists(latents) and not latent_has_view_dim:
3218
+ latents = rearrange(latents, 'b t ... -> b t 1 ...')
3219
+
3220
+ # if raw video passed in, tokenize
3221
+
2028
3222
  if exists(video):
3223
+ assert video.ndim == 6
3224
+
3225
+ video, unpack_views = pack_one(video, '* c t vh vw')
2029
3226
  assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
2030
3227
 
2031
3228
  latents = self.video_tokenizer.tokenize(video)
3229
+ latents = unpack_views(latents, '* t n d')
3230
+ latents = rearrange(latents, 'b v t n d -> b t v n d')
2032
3231
 
2033
- if latents.ndim == 3:
2034
- latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
3232
+ if latents.ndim == 4:
3233
+ latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
2035
3234
 
2036
- assert latents.shape[-2:] == self.latent_shape
3235
+ assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
3236
+ assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
2037
3237
 
2038
3238
  # variables
2039
3239
 
@@ -2108,16 +3308,17 @@ class DynamicsWorldModel(Module):
2108
3308
 
2109
3309
  # times is from 0 to 1
2110
3310
 
2111
- times = self.get_times_from_signal_level(signal_levels, latents)
3311
+ times = self.get_times_from_signal_level(signal_levels)
2112
3312
 
2113
3313
  if not latent_is_noised:
2114
3314
  # get the noise
2115
3315
 
2116
3316
  noise = randn_like(latents)
3317
+ aligned_times = align_dims_left(times, latents)
2117
3318
 
2118
3319
  # noise from 0 as noise to 1 as data
2119
3320
 
2120
- noised_latents = noise.lerp(latents, times)
3321
+ noised_latents = noise.lerp(latents, aligned_times)
2121
3322
 
2122
3323
  else:
2123
3324
  noised_latents = latents
@@ -2165,6 +3366,27 @@ class DynamicsWorldModel(Module):
2165
3366
 
2166
3367
  reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
2167
3368
 
3369
+ # maybe proprioception
3370
+
3371
+ assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
3372
+
3373
+ noised_proprio = None
3374
+
3375
+ if self.has_proprio:
3376
+
3377
+ if not latent_is_noised:
3378
+ # get the noise
3379
+
3380
+ proprio_noise = randn_like(proprio)
3381
+ aligned_times = align_dims_left(times, proprio)
3382
+
3383
+ # noise from 0 as noise to 1 as data
3384
+
3385
+ noised_proprio = proprio_noise.lerp(proprio, aligned_times)
3386
+
3387
+ else:
3388
+ noised_proprio = proprio
3389
+
2168
3390
  # maybe create the action tokens
2169
3391
 
2170
3392
  if exists(discrete_actions) or exists(continuous_actions):
@@ -2185,16 +3407,27 @@ class DynamicsWorldModel(Module):
2185
3407
 
2186
3408
  action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
2187
3409
 
3410
+ elif self.action_embedder.has_actions:
3411
+ action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
3412
+
2188
3413
  else:
2189
3414
  action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
2190
3415
 
2191
3416
  # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
2192
3417
 
2193
- def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False):
3418
+ def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
3419
+
2194
3420
  # latents to spatial tokens
2195
3421
 
2196
3422
  space_tokens = self.latents_to_spatial_tokens(noised_latents)
2197
3423
 
3424
+ # maybe add view embedding
3425
+
3426
+ if self.video_has_multi_view:
3427
+ space_tokens = add('b t v ... d, v d', space_tokens, self.view_emb)
3428
+
3429
+ # merge spatial tokens
3430
+
2198
3431
  space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
2199
3432
 
2200
3433
  num_spatial_tokens = space_tokens.shape[-2]
@@ -2212,6 +3445,13 @@ class DynamicsWorldModel(Module):
2212
3445
 
2213
3446
  registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
2214
3447
 
3448
+ # maybe proprio
3449
+
3450
+ if exists(noised_proprio):
3451
+ proprio_token = self.to_proprio_token(noised_proprio)
3452
+ else:
3453
+ proprio_token = registers[:, :, 0:0]
3454
+
2215
3455
  # determine signal + step size embed for their diffusion forcing + shortcut
2216
3456
 
2217
3457
  signal_embed = self.signal_levels_embed(signal_levels)
@@ -2224,15 +3464,15 @@ class DynamicsWorldModel(Module):
2224
3464
 
2225
3465
  # pack to tokens for attending
2226
3466
 
2227
- tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
3467
+ tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
2228
3468
 
2229
3469
  # attention
2230
3470
 
2231
- tokens = self.transformer(tokens)
3471
+ tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
2232
3472
 
2233
3473
  # unpack
2234
3474
 
2235
- flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
3475
+ flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
2236
3476
 
2237
3477
  # pooling
2238
3478
 
@@ -2240,10 +3480,22 @@ class DynamicsWorldModel(Module):
2240
3480
 
2241
3481
  pred = self.to_latent_pred(space_tokens)
2242
3482
 
3483
+ # maybe proprio
3484
+
3485
+ if self.has_proprio:
3486
+ pred_proprio = self.to_proprio_pred(proprio_token)
3487
+
3488
+ pred = (pred, pred_proprio)
3489
+
3490
+ # returning
3491
+
2243
3492
  if not return_agent_tokens:
2244
3493
  return pred
2245
3494
 
2246
- return pred, agent_tokens
3495
+ if not return_time_kv_cache:
3496
+ return pred, agent_tokens
3497
+
3498
+ return pred, (agent_tokens, next_time_kv_cache)
2247
3499
 
2248
3500
  # curry into get_prediction what does not change during first call as well as the shortcut ones
2249
3501
 
@@ -2251,13 +3503,47 @@ class DynamicsWorldModel(Module):
2251
3503
 
2252
3504
  # forward the network
2253
3505
 
2254
- pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
3506
+ pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
2255
3507
 
2256
3508
  if return_pred_only:
2257
- if not return_agent_tokens:
3509
+ if not return_intermediates:
3510
+ return pred
3511
+
3512
+ return pred, (encoded_agent_tokens, next_time_kv_cache)
3513
+
3514
+ # pack the predictions to calculate flow for different modalities all at once
3515
+
3516
+ if self.has_proprio:
3517
+ pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
3518
+
3519
+ noised, _ = pack((noised_latents, noised_proprio), 'b t *')
3520
+ data, _ = pack((latents, proprio), 'b t *')
3521
+ noise, _ = pack((noise, proprio_noise), 'b t *')
3522
+ else:
3523
+ noised = noised_latents
3524
+ data = latents
3525
+
3526
+ # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
3527
+
3528
+ def maybe_pack_unpack(fn):
3529
+ @wraps(fn)
3530
+ @torch.no_grad()
3531
+ def inner(noised, *args, **kwargs):
3532
+
3533
+ noised_proprio = None
3534
+
3535
+ if self.has_proprio:
3536
+ noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
3537
+
3538
+ pred = fn(noised, noised_proprio, *args, **kwargs)
3539
+
3540
+ if self.has_proprio:
3541
+ pred, _ = pack(pred, 'b t *')
3542
+
2258
3543
  return pred
3544
+ return inner
2259
3545
 
2260
- return pred, encoded_agent_tokens
3546
+ wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
2261
3547
 
2262
3548
  # determine the target for the loss
2263
3549
 
@@ -2274,46 +3560,45 @@ class DynamicsWorldModel(Module):
2274
3560
  # x-space as in paper is in else clause
2275
3561
 
2276
3562
  if is_v_space_pred:
2277
- pred_target = flow = latents - noise
3563
+ pred_target = flow = data - noise
2278
3564
  else:
2279
- pred_target = latents
3565
+ pred_target = data
2280
3566
  else:
2281
3567
  # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
2282
3568
 
2283
3569
  # basically a consistency loss where you ensure quantity of two half steps equals one step
2284
3570
  # dreamer then makes it works for x-space with some math
2285
3571
 
2286
- get_prediction_no_grad = torch.no_grad()(_get_prediction)
2287
-
2288
3572
  step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
2289
3573
  half_step_size = 2 ** step_sizes_log2_minus_one
2290
3574
 
2291
- first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
3575
+ first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
2292
3576
 
2293
3577
  # first derive b'
2294
3578
 
2295
3579
  if is_v_space_pred:
2296
3580
  first_step_pred_flow = first_step_pred
2297
3581
  else:
2298
- first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
2299
- first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
3582
+ first_times = self.get_times_from_signal_level(signal_levels, noised)
3583
+
3584
+ first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
2300
3585
 
2301
3586
  # take a half step
2302
3587
 
2303
- half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
3588
+ half_step_size_align_left = align_dims_left(half_step_size, noised)
2304
3589
 
2305
- denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
3590
+ denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
2306
3591
 
2307
3592
  # get second prediction for b''
2308
3593
 
2309
3594
  signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
2310
- second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
3595
+ second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
2311
3596
 
2312
3597
  if is_v_space_pred:
2313
3598
  second_step_pred_flow = second_step_pred
2314
3599
  else:
2315
- second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
2316
- second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
3600
+ second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
3601
+ second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
2317
3602
 
2318
3603
  # pred target is sg(b' + b'') / 2
2319
3604
 
@@ -2322,7 +3607,7 @@ class DynamicsWorldModel(Module):
2322
3607
  # need to convert x-space to v-space
2323
3608
 
2324
3609
  if is_x_space:
2325
- pred = (pred - noised_latents) / (1. - first_times)
3610
+ pred = (pred - noised) / (1. - first_times)
2326
3611
  maybe_shortcut_loss_weight = (1. - first_times) ** 2
2327
3612
 
2328
3613
  # mse loss
@@ -2335,9 +3620,23 @@ class DynamicsWorldModel(Module):
2335
3620
 
2336
3621
  if exists(self.loss_weight_fn):
2337
3622
  loss_weight = self.loss_weight_fn(times)
3623
+ loss_weight = align_dims_left(loss_weight, flow_losses)
3624
+
2338
3625
  flow_losses = flow_losses * loss_weight
2339
3626
 
2340
- flow_loss = flow_losses.mean()
3627
+ # handle variable lengths if needed
3628
+
3629
+ is_var_len = exists(lens)
3630
+
3631
+ if is_var_len:
3632
+
3633
+ loss_mask = lens_to_mask(lens, time)
3634
+ loss_mask_without_last = loss_mask[:, :-1]
3635
+
3636
+ flow_loss = flow_losses[loss_mask].mean()
3637
+
3638
+ else:
3639
+ flow_loss = flow_losses.mean()
2341
3640
 
2342
3641
  # now take care of the agent token losses
2343
3642
 
@@ -2348,58 +3647,114 @@ class DynamicsWorldModel(Module):
2348
3647
  if rewards.ndim == 2: # (b t)
2349
3648
  encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
2350
3649
 
2351
- reward_pred = self.to_reward_pred(encoded_agent_tokens)
2352
- reward_loss = F.cross_entropy(reward_pred, two_hot_encoding)
3650
+ reward_pred = self.to_reward_pred(encoded_agent_tokens[:, :-1])
3651
+
3652
+ reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')
3653
+
3654
+ reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)
3655
+
3656
+ reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
3657
+
3658
+ reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
3659
+
3660
+ reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
3661
+
3662
+ if is_var_len:
3663
+ reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
3664
+ else:
3665
+ reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
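# [editor's note] sketch of what a `create_multi_token_prediction_targets` style helper
# plausibly produces (the real helper lives elsewhere in this file): for every position it
# stacks the next `mtp` targets along a new dimension and returns a mask marking which of
# those shifted targets actually exist before the end of the sequence.

import torch

def multi_token_targets_sketch(seq, mtp):
    # seq: (batch, time, ...) -> targets: (batch, time, mtp, ...), mask: (batch, time, mtp)
    batch, time = seq.shape[:2]

    targets, masks = [], []

    for shift in range(mtp):
        targets.append(torch.roll(seq, shifts = -shift, dims = 1))  # seq[t + shift], wrapped at the end
        masks.append((torch.arange(time, device = seq.device) < (time - shift)).expand(batch, time))

    return torch.stack(targets, dim = 2), torch.stack(masks, dim = 2)  # wrapped positions are masked out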
2353
3666
 
2354
3667
  # maybe autoregressive action loss
2355
3668
 
2356
- behavior_clone_loss = self.zero
3669
+ discrete_action_loss = self.zero
3670
+ continuous_action_loss = self.zero
2357
3671
 
2358
3672
  if (
2359
3673
  self.num_agents == 1 and
2360
3674
  add_autoregressive_action_loss and
3675
+ time > 1 and
2361
3676
  (exists(discrete_actions) or exists(continuous_actions))
2362
3677
  ):
2363
3678
  assert self.action_embedder.has_actions
2364
3679
 
3680
+ # handle actions having time vs time - 1 length
3681
+ # remove the first action if it is equal to time (as it would come from some agent token in the past)
3682
+
3683
+ if exists(discrete_actions) and discrete_actions.shape[1] == time:
3684
+ discrete_actions = discrete_actions[:, 1:]
3685
+
3686
+ if exists(continuous_actions) and continuous_actions.shape[1] == time:
3687
+ continuous_actions = continuous_actions[:, 1:]
3688
+
2365
3689
  # only for 1 agent
2366
3690
 
2367
3691
  agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
2368
3692
  policy_embed = self.policy_head(agent_tokens[:, :-1])
2369
3693
 
3694
+ # constitute multi token prediction targets
3695
+
3696
+ discrete_action_targets = continuous_action_targets = None
3697
+
3698
+ if exists(discrete_actions):
3699
+ discrete_action_targets, discrete_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
3700
+ discrete_action_targets = rearrange(discrete_action_targets, 'b t mtp ... -> mtp b t ...')
3701
+ discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
3702
+
3703
+ if exists(continuous_actions):
3704
+ continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
3705
+ continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
3706
+ continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
3707
+
2370
3708
  discrete_log_probs, continuous_log_probs = self.action_embedder.log_probs(
2371
3709
  policy_embed,
2372
- discrete_targets = discrete_actions[:, 1:] if exists(discrete_actions) else None,
2373
- continuous_targets = continuous_actions[:, 1:] if exists(continuous_actions) else None
3710
+ discrete_targets = discrete_action_targets if exists(discrete_actions) else None,
3711
+ continuous_targets = continuous_action_targets if exists(continuous_actions) else None
2374
3712
  )
2375
3713
 
2376
3714
  if exists(discrete_log_probs):
2377
- behavior_clone_loss = behavior_clone_loss + discrete_log_probs.sum(dim = -1).mean()
3715
+ discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
3716
+
3717
+ if is_var_len:
3718
+ discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
3719
+ discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
3720
+ else:
3721
+ discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
2378
3722
 
2379
3723
  if exists(continuous_log_probs):
2380
- behavior_clone_loss = behavior_clone_loss + continuous_log_probs.sum(dim = -1).mean()
3724
+ continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
3725
+
3726
+ if is_var_len:
3727
+ continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
3728
+ continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
3729
+ else:
3730
+ continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
3731
+
3732
+ # handle loss normalization
3733
+
3734
+ losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
3735
+
3736
+ if exists(self.flow_loss_normalizer):
3737
+ flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
2381
3738
 
2382
- # gather losses
3739
+ if exists(rewards) and exists(self.reward_loss_normalizer):
3740
+ reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)
3741
+
3742
+ if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
3743
+ discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)
3744
+
3745
+ if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
3746
+ continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)
3747
+
3748
+ # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
2383
3749
 
2384
3750
  total_loss = (
2385
- flow_loss +
2386
- reward_loss * self.reward_loss_weight +
2387
- behavior_clone_loss * self.behavior_clone_weight
3751
+ flow_loss * self.latent_flow_loss_weight +
3752
+ (reward_loss * self.reward_loss_weight).sum() +
3753
+ (discrete_action_loss * self.discrete_action_loss_weight).sum() +
3754
+ (continuous_action_loss * self.continuous_action_loss_weight).sum()
2388
3755
  )
2389
3756
 
2390
3757
  if not return_all_losses:
2391
3758
  return total_loss
2392
3759
 
2393
- losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
2394
-
2395
3760
  return total_loss, losses
2396
-
2397
- # dreamer
2398
-
2399
- class Dreamer(Module):
2400
- def __init__(
2401
- self,
2402
- video_tokenizer: VideoTokenizer,
2403
- dynamics_model: DynamicsWorldModel,
2404
- ):
2405
- super().__init__()