dreamer4 0.1.4-py3-none-any.whl → 0.1.15-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
dreamer4/dreamer4.py CHANGED
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
  from torch.distributions import Normal, kl
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
- from torch.utils._pytree import tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import torchvision
  from torchvision.models import VGG16_Weights
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp

  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
  from assoc_scan import AssocScan

  # ein related
@@ -68,10 +70,14 @@ except ImportError:

  LinearNoBias = partial(Linear, bias = False)

- TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

  WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

+ AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
  MaybeTensor = Tensor | None

  @dataclass
@@ -91,6 +97,14 @@ class Experience:
      agent_index: int = 0
      is_from_world_model: bool = True

+     def cpu(self):
+         return self.to(torch.device('cpu'))
+
+     def to(self, device):
+         experience_dict = asdict(self)
+         experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+         return Experience(**experience_dict)
+
  def combine_experiences(
      exps: list[Experiences]
  ) -> Experience:
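
The new `to`/`cpu` methods lean on torch's pytree utilities: `asdict` turns the dataclass into a dict, and `tree_map` applies the device move only to tensor leaves, so plain fields like `agent_index` pass through unchanged (`SimTrainer` later uses `.cpu()` to offload stored rollouts). A minimal, self-contained sketch of that pattern, with a toy `state` dict standing in for `asdict(Experience)`:

```python
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

# a toy nested container standing in for asdict(Experience)
state = dict(
    latents = torch.randn(2, 4),
    rewards = [torch.zeros(3)],
    agent_index = 0
)

# move only the tensor leaves, leave everything else untouched
moved = tree_map(lambda t: t.to('cpu') if is_tensor(t) else t, state)

assert moved['agent_index'] == 0              # non-tensor fields are untouched
assert moved['latents'].device.type == 'cpu'  # tensors were moved
```
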
@@ -175,6 +189,13 @@ def sample_prob(prob):
  def is_power_two(num):
      return log2(num).is_integer()

+ def maybe(fn):
+     def inner(t, *args, **kwargs):
+         if not exists(t) or not exists(fn):
+             return None
+         return fn(t)
+     return inner
+
  # tensor helpers

  def is_empty(t):
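
The new `maybe` helper wraps a function so that a `None` input (or a `None` function) short-circuits to `None` instead of raising; the transformer below uses it as `maybe(pre_attn_rearrange)(residual_values)`. A self-contained demonstration of the behavior exactly as defined in this hunk:

```python
def exists(v):
    return v is not None

def maybe(fn):
    def inner(t, *args, **kwargs):
        if not exists(t) or not exists(fn):
            return None
        return fn(t)
    return inner

double = maybe(lambda x: x * 2)

assert double(3) == 6            # ordinary call passes through
assert double(None) is None      # None input short-circuits
assert maybe(None)(3) is None    # a missing function also yields None
```
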
@@ -209,6 +230,14 @@ def mean_log_var_to_distr(
      std = (0.5 * log_var).exp()
      return Normal(mean, std)

+ def safe_stack(tensors, dim = 0):
+     tensors = [*filter(exists, tensors)]
+
+     if len(tensors) == 0:
+         return None
+
+     return stack(tensors, dim = dim)
+
  def safe_cat(tensors, dim):
      tensors = [*filter(exists, tensors)]

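
`safe_stack` mirrors the existing `safe_cat`: `None` entries are filtered out before stacking, and an all-`None` list collapses to `None` rather than erroring — which is what lets the transformer below return `None` for the time or space inputs when a branch recorded nothing. A quick check of those semantics:

```python
import torch
from torch import stack

def exists(v):
    return v is not None

def safe_stack(tensors, dim = 0):
    tensors = [*filter(exists, tensors)]
    if len(tensors) == 0:
        return None
    return stack(tensors, dim = dim)

a, b = torch.ones(3), torch.zeros(3)

assert safe_stack([a, None, b]).shape == (2, 3)  # Nones are dropped before stacking
assert safe_stack([None, None]) is None          # nothing to stack -> None, not an error
```
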
@@ -1262,7 +1291,8 @@ class Attention(Module):
          pre_rmsnorm = True,
          gate_values = True,
          rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-         rmsnorm_key = True
+         rmsnorm_key = True,
+         value_residual = True
      ):
          super().__init__()
          self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1301,6 +1331,14 @@ class Attention(Module):
          self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
          self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()

+         # value residual
+
+         self.to_learned_value_residual_mix = nn.Sequential(
+             nn.Linear(dim, heads),
+             Rearrange('b n h -> b h n 1'),
+             nn.Sigmoid()
+         ) if value_residual else None
+
      def muon_parameters(self):
          # omit the queries and keys for now given what we learned from kimi 2 paper

@@ -1313,8 +1351,9 @@ class Attention(Module):
          self,
          tokens, # (b n d)
          kv_cache = None,
-         return_kv_cache = False,
+         return_intermediates = False,
          rotary_pos_emb = None,
+         residual_values = None, # (b n h d)
          attend_fn: Callable | None = None
      ):
          tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1327,6 +1366,17 @@ class Attention(Module):

          q, k, v = map(self.split_heads, (q, k, v))

+         # handle maybe value residual
+
+         if exists(residual_values):
+             residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+             assert exists(self.to_learned_value_residual_mix)
+
+             learned_mix = self.to_learned_value_residual_mix(tokens)
+
+             v = v.lerp(residual_values, learned_mix)
+
          # qk rmsnorm

          q = self.q_heads_rmsnorm(q)
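
The block above is the learned value-residual mix (https://arxiv.org/abs/2410.17897, with the learned mixing noted in the `value_residual` comment below): a per-head sigmoid gate produced from the tokens interpolates this layer's values toward values projected from the first layer's input. A standalone sketch with toy shapes:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq, dim, heads, dim_head = 2, 16, 64, 4, 16

# per-head gates in (0, 1), one scalar per token position
to_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

tokens = torch.randn(batch, seq, dim)
v = torch.randn(batch, heads, seq, dim_head)               # this layer's values
residual_values = torch.randn(batch, heads, seq, dim_head) # projected from the first layer's input

mix = to_mix(tokens)              # (b h n 1)
v = v.lerp(residual_values, mix)  # v * (1 - mix) + residual_values * mix
```
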
@@ -1367,10 +1417,10 @@ class Attention(Module):

          out = inverse_packed_batch(out)

-         if not return_kv_cache:
+         if not return_intermediates:
              return out

-         return out, stack((k, v))
+         return out, AttentionIntermediates(stack((k, v)), tokens)

  # feedforward

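
Returning a namedtuple instead of a bare `stack((k, v))` keeps positional unpacking working for existing callers while giving new consumers named access to the extra `normed_inputs` field. A tiny illustration of the pattern:

```python
from collections import namedtuple
import torch

AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))

k = v = torch.randn(1, 4, 8, 16)
normed = torch.randn(1, 8, 64)

inter = AttentionIntermediates(torch.stack((k, v)), normed)

kv_cache, normed_inputs = inter           # tuple-style unpacking still works
assert inter.next_kv_cache.shape[0] == 2  # named access for new consumers
```
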
@@ -1410,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
          self,
          dim,
          depth,
+         attn_heads = 8,
          attn_dim_head = 64,
          attn_softclamp_value = 50.,
          time_block_every = 4,
@@ -1418,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
          num_residual_streams = 1,
          num_special_spatial_tokens = 1,
          special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-         final_norm = True
+         final_norm = True,
+         value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
      ):
          super().__init__()
          assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1439,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):

          self.time_rotary = Rotary1D(attn_dim_head)

+         # project initial for value residuals
+
+         self.value_residual = value_residual
+
+         if value_residual:
+             dim_inner = attn_dim_head * attn_heads
+
+             self.to_value_residual = nn.Sequential(
+                 nn.RMSNorm(dim),
+                 nn.Linear(dim, dim_inner, bias = False),
+                 Rearrange('... (h d) -> ... h d', h = attn_heads)
+             )
+
          # transformer

          layers = []
@@ -1450,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
              is_time_block = divisible_by(layer_index, time_block_every)
              is_time.append(is_time_block)

-             rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-             rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+             rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+             rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()

              layers.append(ModuleList([
                  rearrange_to_attend,
                  rearrange_from_attend,
-                 hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                 hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                  hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
              ]))

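
The rearrange patterns widen `d` to `...` because the per-layer residual values carry an extra heads dimension (`b t s h d`) and must ride through the same time/space transposition as the tokens. The underlying axial trick is unchanged: for a time block, space is folded toward batch so attention runs along the time axis. A toy check:

```python
import torch
from einops import rearrange

b, t, s, d = 2, 10, 64, 32

tokens = torch.randn(b, t, s, d)
time_major = rearrange(tokens, 'b t s ... -> b s t ...')  # attend over t; s becomes batch-like
assert time_major.shape == (b, s, t, d)

values = torch.randn(b, t, s, 4, 16)  # (b t s h d) residual values with a heads dim
assert rearrange(values, 'b t s ... -> b s t ...').shape == (b, s, t, 4, 16)
```
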
@@ -1484,7 +1549,7 @@ class AxialSpaceTimeTransformer(Module):
          self,
          tokens, # (b t s d)
          kv_cache: Tensor | None = None, # (y 2 b h t d)
-         return_kv_cache = False
+         return_intermediates = False

      ): # (b t s d) | (y 2 b h t d)

@@ -1507,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):

          time_attn_kv_caches = []

-
          if has_kv_cache:
              past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

@@ -1525,6 +1589,18 @@ class AxialSpaceTimeTransformer(Module):

          rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)

+         # value residual
+
+         residual_values = None
+
+         if self.value_residual:
+             residual_values = self.to_value_residual(tokens)
+
+         # normed attention inputs
+
+         normed_time_attn_inputs = []
+         normed_space_attn_inputs = []
+
          # attention

          tokens = self.expand_streams(tokens)
@@ -1543,14 +1619,19 @@ class AxialSpaceTimeTransformer(Module):

              maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None

+             # residual values
+
+             layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
              # attention layer

-             tokens, next_kv_cache = attn(
+             tokens, attn_intermediates = attn(
                  tokens,
                  rotary_pos_emb = layer_rotary_pos_emb,
                  attend_fn = attend_fn,
                  kv_cache = maybe_kv_cache,
-                 return_kv_cache = True
+                 residual_values = layer_residual_values,
+                 return_intermediates = True
              )

              tokens = post_attn_rearrange(tokens)
@@ -1562,7 +1643,13 @@ class AxialSpaceTimeTransformer(Module):
              # save kv cache if is time layer

              if layer_is_time:
-                 time_attn_kv_caches.append(next_kv_cache)
+                 time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+             # save time attention inputs for decorr
+
+             space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+             space_or_time_inputs.append(attn_intermediates.normed_inputs)

          tokens = self.reduce_streams(tokens)

@@ -1572,10 +1659,16 @@ class AxialSpaceTimeTransformer(Module):
              # just concat the past tokens back on for now, todo - clean up the logic
              out = cat((past_tokens, out), dim = 1)

-         if not return_kv_cache:
+         if not return_intermediates:
              return out

-         return out, stack(time_attn_kv_caches)
+         intermediates = TransformerIntermediates(
+             stack(time_attn_kv_caches),
+             safe_stack(normed_time_attn_inputs),
+             safe_stack(normed_space_attn_inputs)
+         )
+
+         return out, intermediates

  # video tokenizer

@@ -1601,12 +1694,15 @@ class VideoTokenizer(Module):
          per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
          lpips_loss_network: Module | None = None,
          lpips_loss_weight = 0.2,
+         encoder_add_decor_aux_loss = False,
+         decor_auxx_loss_weight = 0.1,
+         decorr_sample_frac = 0.25,
          nd_rotary_kwargs: dict = dict(
              rope_min_freq = 1.,
              rope_max_freq = 10000.,
              rope_p_zero_freqs = 0.
          ),
-         num_residual_streams = 1
+         num_residual_streams = 1,
      ):
          super().__init__()

@@ -1701,6 +1797,14 @@ class VideoTokenizer(Module):
          if self.has_lpips_loss:
              self.lpips = LPIPSLoss(lpips_loss_network)

+         # decorr aux loss
+         # https://arxiv.org/abs/2510.14657
+
+         self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+         self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+         self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
      @property
      def device(self):
          return self.zero.device
@@ -1814,7 +1918,7 @@ class VideoTokenizer(Module):

          # encoder attention

-         tokens = self.encoder_transformer(tokens)
+         tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)

          # latent bottleneck

@@ -1836,17 +1940,28 @@ class VideoTokenizer(Module):
          if self.has_lpips_loss:
              lpips_loss = self.lpips(video, recon_video)

+         time_decorr_loss = space_decorr_loss = self.zero
+
+         if self.encoder_add_decor_aux_loss:
+             if exists(time_attn_normed_inputs):
+                 time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+             if exists(space_attn_normed_inputs):
+                 space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
          # losses

          total_loss = (
              recon_loss +
-             lpips_loss * self.lpips_loss_weight
+             lpips_loss * self.lpips_loss_weight +
+             time_decorr_loss * self.decorr_aux_loss_weight +
+             space_decorr_loss * self.decorr_aux_loss_weight
          )

          if not return_all_losses:
              return total_loss

-         losses = (recon_loss, lpips_loss)
+         losses = (recon_loss, lpips_loss, decorr_loss)

          return total_loss, TokenizerLosses(*losses)

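
The `DecorrelationLoss` imported from `vit_pytorch.vit_with_decorr` implements the auxiliary objective of arXiv:2510.14657 over the encoder's normed attention inputs. As an illustration only (this is not the vit-pytorch implementation), a penalty of this family can be written as the squared off-diagonal entries of the feature covariance, pushing channels toward decorrelation:

```python
import torch

def decorrelation_penalty(feats):
    # feats: (..., n, d) - flatten all leading dims, center the features,
    # then penalize off-diagonal covariance so channels stay decorrelated
    feats = feats.reshape(-1, feats.shape[-1])
    feats = feats - feats.mean(dim = 0)
    cov = (feats.T @ feats) / feats.shape[0]   # (d, d) covariance
    off_diag = cov - torch.diag(cov.diagonal())
    return off_diag.pow(2).mean()

loss = decorrelation_penalty(torch.randn(8, 128, 32))
```
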
@@ -1870,10 +1985,9 @@ class DynamicsWorldModel(Module):
          depth = 4,
          pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
          time_block_every = 4, # every 4th block is time
-         attn_kwargs: dict = dict(
-             heads = 8,
-         ),
+         attn_kwargs: dict = dict(),
          transformer_kwargs: dict = dict(),
+         attn_heads = 8,
          attn_dim_head = 64,
          attn_softclamp_value = 50.,
          ff_kwargs: dict = dict(),
@@ -2086,6 +2200,7 @@ class DynamicsWorldModel(Module):
          self.transformer = AxialSpaceTimeTransformer(
              dim = dim,
              depth = depth,
+             attn_heads = attn_heads,
              attn_dim_head = attn_dim_head,
              attn_softclamp_value = attn_softclamp_value,
              attn_kwargs = attn_kwargs,
@@ -2435,6 +2550,8 @@ class DynamicsWorldModel(Module):
      ):
          assert isinstance(experience, Experience)

+         experience = experience.to(self.device)
+
          latents = experience.latents
          actions = experience.actions
          old_log_probs = experience.log_probs
@@ -3325,7 +3442,7 @@ class DynamicsWorldModel(Module):

          # attention

-         tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+         tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

          # unpack

dreamer4/trainers.py CHANGED
@@ -528,7 +528,7 @@ class SimTrainer(Module):

              total_experience += num_experience

-             experiences.append(experience)
+             experiences.append(experience.cpu())

          combined_experiences = combine_experiences(experiences)

{dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.1.4
+ Version: 0.1.15
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
  Requires-Dist: hyper-connections>=0.2.1
  Requires-Dist: torch>=2.4
  Requires-Dist: torchvision
+ Requires-Dist: vit-pytorch>=1.15.3
  Requires-Dist: x-mlps-pytorch>=0.0.29
  Provides-Extra: examples
  Provides-Extra: test
@@ -57,7 +58,7 @@ Description-Content-Type: text/markdown

  Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

- [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+ [Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work

  ## Appreciation

@@ -90,7 +91,7 @@ video = torch.randn(2, 3, 10, 256, 256)
  # learn the tokenizer

  loss = tokenizer(video)
- loss.backward() # ler
+ loss.backward()

  # dynamics world model

dreamer4-0.1.15.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+ dreamer4/dreamer4.py,sha256=BVMAIfhqv7wO0FWo-SBfUnyXEQcMljh6CyaHeZ8GmCI,125018
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+ dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+ dreamer4-0.1.15.dist-info/METADATA,sha256=ghChOd76397jZ_XwFwKRv1lxP1ZFqNgQfSKBUB7DXoo,4973
+ dreamer4-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dreamer4-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ dreamer4-0.1.15.dist-info/RECORD,,
dreamer4-0.1.4.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
- dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
- dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
- dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
- dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- dreamer4-0.1.4.dist-info/RECORD,,