dreamer4 0.1.4__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +68 -16
- dreamer4/trainers.py +1 -1
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.10.dist-info}/METADATA +4 -3
- dreamer4-0.1.10.dist-info/RECORD +8 -0
- dreamer4-0.1.4.dist-info/RECORD +0 -8
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.10.dist-info}/WHEEL +0 -0
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.10.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
|
|
|
14
14
|
from torch.distributions import Normal, kl
|
|
15
15
|
from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
|
|
16
16
|
from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
|
|
17
|
-
from torch.utils._pytree import tree_flatten, tree_unflatten
|
|
17
|
+
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
|
18
18
|
|
|
19
19
|
import torchvision
|
|
20
20
|
from torchvision.models import VGG16_Weights
|
|
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
|
|
|
27
27
|
|
|
28
28
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
29
29
|
|
|
30
|
+
from vit_pytorch.vit_with_decorr import DecorrelationLoss
|
|
31
|
+
|
|
30
32
|
from assoc_scan import AssocScan
|
|
31
33
|
|
|
32
34
|
# ein related
|
|
@@ -68,10 +70,14 @@ except ImportError:
|
|
|
68
70
|
|
|
69
71
|
LinearNoBias = partial(Linear, bias = False)
|
|
70
72
|
|
|
71
|
-
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
|
|
73
|
+
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
|
|
72
74
|
|
|
73
75
|
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
|
|
74
76
|
|
|
77
|
+
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
|
|
78
|
+
|
|
79
|
+
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
|
|
80
|
+
|
|
75
81
|
MaybeTensor = Tensor | None
|
|
76
82
|
|
|
77
83
|
@dataclass
|
|
@@ -91,6 +97,14 @@ class Experience:
|
|
|
91
97
|
agent_index: int = 0
|
|
92
98
|
is_from_world_model: bool = True
|
|
93
99
|
|
|
100
|
+
def cpu(self):
|
|
101
|
+
return self.to(torch.device('cpu'))
|
|
102
|
+
|
|
103
|
+
def to(self, device):
|
|
104
|
+
experience_dict = asdict(self)
|
|
105
|
+
experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
|
|
106
|
+
return Experience(**experience_dict)
|
|
107
|
+
|
|
94
108
|
def combine_experiences(
|
|
95
109
|
exps: list[Experiences]
|
|
96
110
|
) -> Experience:
|
|
@@ -1313,7 +1327,7 @@ class Attention(Module):
|
|
|
1313
1327
|
self,
|
|
1314
1328
|
tokens, # (b n d)
|
|
1315
1329
|
kv_cache = None,
|
|
1316
|
-
|
|
1330
|
+
return_intermediates = False,
|
|
1317
1331
|
rotary_pos_emb = None,
|
|
1318
1332
|
attend_fn: Callable | None = None
|
|
1319
1333
|
):
|
|
@@ -1367,10 +1381,10 @@ class Attention(Module):
|
|
|
1367
1381
|
|
|
1368
1382
|
out = inverse_packed_batch(out)
|
|
1369
1383
|
|
|
1370
|
-
if not
|
|
1384
|
+
if not return_intermediates:
|
|
1371
1385
|
return out
|
|
1372
1386
|
|
|
1373
|
-
return out, stack((k, v))
|
|
1387
|
+
return out, AttentionIntermediates(stack((k, v)), tokens)
|
|
1374
1388
|
|
|
1375
1389
|
# feedforward
|
|
1376
1390
|
|
|
@@ -1484,7 +1498,7 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1484
1498
|
self,
|
|
1485
1499
|
tokens, # (b t s d)
|
|
1486
1500
|
kv_cache: Tensor | None = None, # (y 2 b h t d)
|
|
1487
|
-
|
|
1501
|
+
return_intermediates = False
|
|
1488
1502
|
|
|
1489
1503
|
): # (b t s d) | (y 2 b h t d)
|
|
1490
1504
|
|
|
@@ -1525,6 +1539,11 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1525
1539
|
|
|
1526
1540
|
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
|
1527
1541
|
|
|
1542
|
+
# normed attention inputs
|
|
1543
|
+
|
|
1544
|
+
normed_time_attn_inputs = []
|
|
1545
|
+
normed_space_attn_inputs = []
|
|
1546
|
+
|
|
1528
1547
|
# attention
|
|
1529
1548
|
|
|
1530
1549
|
tokens = self.expand_streams(tokens)
|
|
@@ -1545,12 +1564,12 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1545
1564
|
|
|
1546
1565
|
# attention layer
|
|
1547
1566
|
|
|
1548
|
-
tokens,
|
|
1567
|
+
tokens, attn_intermediates = attn(
|
|
1549
1568
|
tokens,
|
|
1550
1569
|
rotary_pos_emb = layer_rotary_pos_emb,
|
|
1551
1570
|
attend_fn = attend_fn,
|
|
1552
1571
|
kv_cache = maybe_kv_cache,
|
|
1553
|
-
|
|
1572
|
+
return_intermediates = True
|
|
1554
1573
|
)
|
|
1555
1574
|
|
|
1556
1575
|
tokens = post_attn_rearrange(tokens)
|
|
@@ -1562,7 +1581,13 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1562
1581
|
# save kv cache if is time layer
|
|
1563
1582
|
|
|
1564
1583
|
if layer_is_time:
|
|
1565
|
-
time_attn_kv_caches.append(next_kv_cache)
|
|
1584
|
+
time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
|
|
1585
|
+
|
|
1586
|
+
# save time attention inputs for decorr
|
|
1587
|
+
|
|
1588
|
+
space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
|
|
1589
|
+
|
|
1590
|
+
space_or_time_inputs.append(attn_intermediates.normed_inputs)
|
|
1566
1591
|
|
|
1567
1592
|
tokens = self.reduce_streams(tokens)
|
|
1568
1593
|
|
|
@@ -1572,10 +1597,16 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1572
1597
|
# just concat the past tokens back on for now, todo - clean up the logic
|
|
1573
1598
|
out = cat((past_tokens, out), dim = 1)
|
|
1574
1599
|
|
|
1575
|
-
if not
|
|
1600
|
+
if not return_intermediates:
|
|
1576
1601
|
return out
|
|
1577
1602
|
|
|
1578
|
-
|
|
1603
|
+
intermediates = TransformerIntermediates(
|
|
1604
|
+
stack(time_attn_kv_caches),
|
|
1605
|
+
stack(normed_time_attn_inputs),
|
|
1606
|
+
stack(normed_space_attn_inputs)
|
|
1607
|
+
)
|
|
1608
|
+
|
|
1609
|
+
return out, intermediates
|
|
1579
1610
|
|
|
1580
1611
|
# video tokenizer
|
|
1581
1612
|
|
|
@@ -1601,12 +1632,15 @@ class VideoTokenizer(Module):
|
|
|
1601
1632
|
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
|
|
1602
1633
|
lpips_loss_network: Module | None = None,
|
|
1603
1634
|
lpips_loss_weight = 0.2,
|
|
1635
|
+
encoder_add_decor_aux_loss = False,
|
|
1636
|
+
decor_auxx_loss_weight = 0.1,
|
|
1637
|
+
decorr_sample_frac = 0.25,
|
|
1604
1638
|
nd_rotary_kwargs: dict = dict(
|
|
1605
1639
|
rope_min_freq = 1.,
|
|
1606
1640
|
rope_max_freq = 10000.,
|
|
1607
1641
|
rope_p_zero_freqs = 0.
|
|
1608
1642
|
),
|
|
1609
|
-
num_residual_streams = 1
|
|
1643
|
+
num_residual_streams = 1,
|
|
1610
1644
|
):
|
|
1611
1645
|
super().__init__()
|
|
1612
1646
|
|
|
@@ -1701,6 +1735,14 @@ class VideoTokenizer(Module):
|
|
|
1701
1735
|
if self.has_lpips_loss:
|
|
1702
1736
|
self.lpips = LPIPSLoss(lpips_loss_network)
|
|
1703
1737
|
|
|
1738
|
+
# decorr aux loss
|
|
1739
|
+
# https://arxiv.org/abs/2510.14657
|
|
1740
|
+
|
|
1741
|
+
self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
|
|
1742
|
+
self.decorr_aux_loss_weight = decor_auxx_loss_weight
|
|
1743
|
+
|
|
1744
|
+
self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
|
|
1745
|
+
|
|
1704
1746
|
@property
|
|
1705
1747
|
def device(self):
|
|
1706
1748
|
return self.zero.device
|
|
@@ -1814,7 +1856,7 @@ class VideoTokenizer(Module):
|
|
|
1814
1856
|
|
|
1815
1857
|
# encoder attention
|
|
1816
1858
|
|
|
1817
|
-
tokens = self.encoder_transformer(tokens)
|
|
1859
|
+
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
|
|
1818
1860
|
|
|
1819
1861
|
# latent bottleneck
|
|
1820
1862
|
|
|
@@ -1836,17 +1878,25 @@ class VideoTokenizer(Module):
|
|
|
1836
1878
|
if self.has_lpips_loss:
|
|
1837
1879
|
lpips_loss = self.lpips(video, recon_video)
|
|
1838
1880
|
|
|
1881
|
+
time_decorr_loss = space_decorr_loss = self.zero
|
|
1882
|
+
|
|
1883
|
+
if self.encoder_add_decor_aux_loss:
|
|
1884
|
+
time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
|
|
1885
|
+
space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
|
|
1886
|
+
|
|
1839
1887
|
# losses
|
|
1840
1888
|
|
|
1841
1889
|
total_loss = (
|
|
1842
1890
|
recon_loss +
|
|
1843
|
-
lpips_loss * self.lpips_loss_weight
|
|
1891
|
+
lpips_loss * self.lpips_loss_weight +
|
|
1892
|
+
time_decorr_loss * self.decorr_aux_loss_weight +
|
|
1893
|
+
space_decorr_loss * self.decorr_aux_loss_weight
|
|
1844
1894
|
)
|
|
1845
1895
|
|
|
1846
1896
|
if not return_all_losses:
|
|
1847
1897
|
return total_loss
|
|
1848
1898
|
|
|
1849
|
-
losses = (recon_loss, lpips_loss)
|
|
1899
|
+
losses = (recon_loss, lpips_loss, decorr_loss)
|
|
1850
1900
|
|
|
1851
1901
|
return total_loss, TokenizerLosses(*losses)
|
|
1852
1902
|
|
|
@@ -2435,6 +2485,8 @@ class DynamicsWorldModel(Module):
|
|
|
2435
2485
|
):
|
|
2436
2486
|
assert isinstance(experience, Experience)
|
|
2437
2487
|
|
|
2488
|
+
experience = experience.to(self.device)
|
|
2489
|
+
|
|
2438
2490
|
latents = experience.latents
|
|
2439
2491
|
actions = experience.actions
|
|
2440
2492
|
old_log_probs = experience.log_probs
|
|
@@ -3325,7 +3377,7 @@ class DynamicsWorldModel(Module):
|
|
|
3325
3377
|
|
|
3326
3378
|
# attention
|
|
3327
3379
|
|
|
3328
|
-
tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache,
|
|
3380
|
+
tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
|
|
3329
3381
|
|
|
3330
3382
|
# unpack
|
|
3331
3383
|
|
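dreamer4/trainers.py
CHANGED
(hunk not captured in this extract)
{dreamer4-0.1.4.dist-info → dreamer4-0.1.10.dist-info}/METADATA
CHANGED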
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dreamer4
|
|
3
|
-
Version: 0.1.4
|
|
3
|
+
Version: 0.1.10
|
|
4
4
|
Summary: Dreamer 4
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/dreamer4/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/dreamer4
|
|
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
|
|
|
44
44
|
Requires-Dist: hyper-connections>=0.2.1
|
|
45
45
|
Requires-Dist: torch>=2.4
|
|
46
46
|
Requires-Dist: torchvision
|
|
47
|
+
Requires-Dist: vit-pytorch>=1.15.3
|
|
47
48
|
Requires-Dist: x-mlps-pytorch>=0.0.29
|
|
48
49
|
Provides-Extra: examples
|
|
49
50
|
Provides-Extra: test
|
|
@@ -57,7 +58,7 @@ Description-Content-Type: text/markdown
|
|
|
57
58
|
|
|
58
59
|
Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
|
|
59
60
|
|
|
60
|
-
[Discord channel](https://discord.gg/
|
|
61
|
+
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
|
|
61
62
|
|
|
62
63
|
## Appreciation
|
|
63
64
|
|
|
@@ -90,7 +91,7 @@ video = torch.randn(2, 3, 10, 256, 256)
|
|
|
90
91
|
# learn the tokenizer
|
|
91
92
|
|
|
92
93
|
loss = tokenizer(video)
|
|
93
|
-
loss.backward()
|
|
94
|
+
loss.backward()
|
|
94
95
|
|
|
95
96
|
# dynamics world model
|
|
96
97
|
|
|
dreamer4-0.1.10.dist-info/RECORD
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
|
|
3
|
+
dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
|
|
4
|
+
dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
|
|
5
|
+
dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
|
|
6
|
+
dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.1.10.dist-info/RECORD,,
|
dreamer4-0.1.4.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
|
|
3
|
-
dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
|
|
4
|
-
dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
|
|
5
|
-
dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
|
|
6
|
-
dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.1.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|