dreamer4 0.1.4__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
  from torch.distributions import Normal, kl
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
- from torch.utils._pytree import tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import torchvision
  from torchvision.models import VGG16_Weights
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp

  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
  from assoc_scan import AssocScan

  # ein related
@@ -68,10 +70,14 @@ except ImportError:

  LinearNoBias = partial(Linear, bias = False)

- TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

  WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

+ AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
  MaybeTensor = Tensor | None

  @dataclass
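`TokenizerLosses` grows from two fields to four in this release, so downstream code that unpacked it positionally needs two extra names; access by field name is unaffected. A minimal standalone sketch of the new layout (the values below are placeholders, not real losses):

```python
from collections import namedtuple

# 0.1.10 layout: two decorrelation terms follow the reconstruction and lpips losses
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

losses = TokenizerLosses(0.04, 0.12, 0.0, 0.0)        # placeholder numbers
recon, lpips, time_decorr, space_decorr = losses      # positional unpacking now needs four names
print(losses.time_decorr)                             # attribute access keeps working
```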
@@ -91,6 +97,14 @@ class Experience:
  agent_index: int = 0
  is_from_world_model: bool = True

+ def cpu(self):
+ return self.to(torch.device('cpu'))
+
+ def to(self, device):
+ experience_dict = asdict(self)
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+ return Experience(**experience_dict)
+
  def combine_experiences(
  exps: list[Experience]
  ) -> Experience:
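The new `Experience.to` flattens the dataclass with `asdict` and lets `tree_map` move every tensor leaf while passing non-tensor fields (ints, bools) through untouched. A minimal standalone illustration of that mechanism, using a stand-in dataclass rather than the library's `Experience`:

```python
from dataclasses import dataclass, asdict

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

@dataclass
class TinyExperience:                 # stand-in with mixed tensor / non-tensor fields
    latents: torch.Tensor
    agent_index: int = 0

exp = TinyExperience(latents = torch.randn(2, 4))

# same recipe as Experience.to: move tensor leaves, pass everything else through
moved = TinyExperience(**tree_map(lambda t: t.to('cpu') if is_tensor(t) else t, asdict(exp)))
```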
@@ -1313,7 +1327,7 @@ class Attention(Module):
  self,
  tokens, # (b n d)
  kv_cache = None,
- return_kv_cache = False,
+ return_intermediates = False,
  rotary_pos_emb = None,
  attend_fn: Callable | None = None
  ):
@@ -1367,10 +1381,10 @@ class Attention(Module):

  out = inverse_packed_batch(out)

- if not return_kv_cache:
+ if not return_intermediates:
  return out

- return out, stack((k, v))
+ return out, AttentionIntermediates(stack((k, v)), tokens)

  # feedforward

@@ -1484,7 +1498,7 @@ class AxialSpaceTimeTransformer(Module):
  self,
  tokens, # (b t s d)
  kv_cache: Tensor | None = None, # (y 2 b h t d)
- return_kv_cache = False
+ return_intermediates = False

  ): # (b t s d) | (y 2 b h t d)

@@ -1525,6 +1539,11 @@ class AxialSpaceTimeTransformer(Module):

  rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)

+ # normed attention inputs
+
+ normed_time_attn_inputs = []
+ normed_space_attn_inputs = []
+
  # attention

  tokens = self.expand_streams(tokens)
@@ -1545,12 +1564,12 @@ class AxialSpaceTimeTransformer(Module):

  # attention layer

- tokens, next_kv_cache = attn(
+ tokens, attn_intermediates = attn(
  tokens,
  rotary_pos_emb = layer_rotary_pos_emb,
  attend_fn = attend_fn,
  kv_cache = maybe_kv_cache,
- return_kv_cache = True
+ return_intermediates = True
  )

  tokens = post_attn_rearrange(tokens)
@@ -1562,7 +1581,13 @@ class AxialSpaceTimeTransformer(Module):
  # save kv cache if is time layer

  if layer_is_time:
- time_attn_kv_caches.append(next_kv_cache)
+ time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+ # save time attention inputs for decorr
+
+ space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+ space_or_time_inputs.append(attn_intermediates.normed_inputs)

  tokens = self.reduce_streams(tokens)

@@ -1572,10 +1597,16 @@ class AxialSpaceTimeTransformer(Module):
  # just concat the past tokens back on for now, todo - clean up the logic
  out = cat((past_tokens, out), dim = 1)

- if not return_kv_cache:
+ if not return_intermediates:
  return out

- return out, stack(time_attn_kv_caches)
+ intermediates = TransformerIntermediates(
+ stack(time_attn_kv_caches),
+ stack(normed_time_attn_inputs),
+ stack(normed_space_attn_inputs)
+ )
+
+ return out, intermediates

  # video tokenizer

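With `return_intermediates = True` the axial transformer now returns a `TransformerIntermediates` namedtuple instead of a bare stack of KV caches. A hedged sketch of how the call sites in this diff consume it; `transformer` and `tokens` are assumed to already exist and are not constructed here:

```python
out, intermediates = transformer(tokens, return_intermediates = True)

next_time_kv_cache = intermediates.next_kv_cache        # stacked (k, v) from the time-attention layers
normed_time_inputs = intermediates.normed_time_inputs   # per-layer normed attention inputs, kept for the decorrelation loss
normed_space_inputs = intermediates.normed_space_inputs

# plain tuple unpacking also works, as the tokenizer encoder and world model do below
out, (kv_cache, time_inputs, space_inputs) = transformer(tokens, return_intermediates = True)
```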
@@ -1601,12 +1632,15 @@ class VideoTokenizer(Module):
  per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
  lpips_loss_network: Module | None = None,
  lpips_loss_weight = 0.2,
+ encoder_add_decorr_aux_loss = False,
+ decorr_aux_loss_weight = 0.1,
+ decorr_sample_frac = 0.25,
  nd_rotary_kwargs: dict = dict(
  rope_min_freq = 1.,
  rope_max_freq = 10000.,
  rope_p_zero_freqs = 0.
  ),
- num_residual_streams = 1
+ num_residual_streams = 1,
  ):
  super().__init__()

@@ -1701,6 +1735,14 @@ class VideoTokenizer(Module):
  if self.has_lpips_loss:
  self.lpips = LPIPSLoss(lpips_loss_network)

+ # decorr aux loss
+ # https://arxiv.org/abs/2510.14657
+
+ self.encoder_add_decorr_aux_loss = encoder_add_decorr_aux_loss
+ self.decorr_aux_loss_weight = decorr_aux_loss_weight
+
+ self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decorr_aux_loss else None
+
  @property
  def device(self):
  return self.zero.device
@@ -1814,7 +1856,7 @@ class VideoTokenizer(Module):

  # encoder attention

- tokens = self.encoder_transformer(tokens)
+ tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)

  # latent bottleneck

@@ -1836,17 +1878,25 @@ class VideoTokenizer(Module):
  if self.has_lpips_loss:
  lpips_loss = self.lpips(video, recon_video)

+ time_decorr_loss = space_decorr_loss = self.zero
+
+ if self.encoder_add_decorr_aux_loss:
+ time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+ space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
  # losses

  total_loss = (
  recon_loss +
- lpips_loss * self.lpips_loss_weight
+ lpips_loss * self.lpips_loss_weight +
+ time_decorr_loss * self.decorr_aux_loss_weight +
+ space_decorr_loss * self.decorr_aux_loss_weight
  )

  if not return_all_losses:
  return total_loss

- losses = (recon_loss, lpips_loss)
+ losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

  return total_loss, TokenizerLosses(*losses)

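A hedged sketch of exercising the new auxiliary loss end to end, assuming `tokenizer` is a `VideoTokenizer` built with `encoder_add_decorr_aux_loss = True` (its remaining constructor arguments are omitted here) and that `return_all_losses` is accepted as a forward keyword, as the code above suggests:

```python
import torch

video = torch.randn(2, 3, 10, 256, 256)   # (batch, channels, frames, height, width), as in the README example

total_loss, losses = tokenizer(video, return_all_losses = True)

# the two decorrelation terms stay at zero unless encoder_add_decorr_aux_loss was enabled
print(losses.recon, losses.lpips, losses.time_decorr, losses.space_decorr)

total_loss.backward()
```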
@@ -2435,6 +2485,8 @@ class DynamicsWorldModel(Module):
  ):
  assert isinstance(experience, Experience)

+ experience = experience.to(self.device)
+
  latents = experience.latents
  actions = experience.actions
  old_log_probs = experience.log_probs
@@ -3325,7 +3377,7 @@ class DynamicsWorldModel(Module):

  # attention

- tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+ tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

  # unpack

dreamer4/trainers.py CHANGED
@@ -528,7 +528,7 @@ class SimTrainer(Module):

  total_experience += num_experience

- experiences.append(experience)
+ experiences.append(experience.cpu())

  combined_experiences = combine_experiences(experiences)

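Taken together with the `experience.to(self.device)` call added to `DynamicsWorldModel` above, the trainer now keeps collected rollouts on the CPU and only moves them back to the accelerator when they are consumed. A rough sketch of that pattern; `gather_experience`, `num_rollouts` and the `world_model.learn` call are stand-ins for illustration, not names confirmed by this diff:

```python
experiences = []

for _ in range(num_rollouts):
    experience = gather_experience()           # assumed rollout step returning an Experience
    experiences.append(experience.cpu())       # offload tensors so accelerator memory is freed between rollouts

combined = combine_experiences(experiences)    # stays on CPU here
world_model.learn(combined)                    # hypothetical update call; the diff shows .to(self.device) happening inside
```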
dreamer4-0.1.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.1.4
+ Version: 0.1.10
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
  Requires-Dist: hyper-connections>=0.2.1
  Requires-Dist: torch>=2.4
  Requires-Dist: torchvision
+ Requires-Dist: vit-pytorch>=1.15.3
  Requires-Dist: x-mlps-pytorch>=0.0.29
  Provides-Extra: examples
  Provides-Extra: test
@@ -57,7 +58,7 @@ Description-Content-Type: text/markdown

  Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

- [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+ [Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work

  ## Appreciation

@@ -90,7 +91,7 @@ video = torch.randn(2, 3, 10, 256, 256)
  # learn the tokenizer

  loss = tokenizer(video)
- loss.backward() # ler
+ loss.backward()

  # dynamics world model

dreamer4-0.1.10.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+ dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+ dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+ dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
+ dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ dreamer4-0.1.10.dist-info/RECORD,,
dreamer4-0.1.4.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
- dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
- dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
- dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
- dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- dreamer4-0.1.4.dist-info/RECORD,,