dreamer4 0.0.99__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreamer4-0.0.99 → dreamer4-0.1.5}/PKG-INFO +94 -3
- dreamer4-0.1.5/README.md +112 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4/dreamer4.py +92 -32
- {dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4/trainers.py +1 -1
- {dreamer4-0.0.99 → dreamer4-0.1.5}/pyproject.toml +1 -1
- {dreamer4-0.0.99 → dreamer4-0.1.5}/tests/test_dreamer.py +6 -0
- dreamer4-0.0.99/README.md +0 -21
- {dreamer4-0.0.99 → dreamer4-0.1.5}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/.gitignore +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/LICENSE +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4-fig2.png +0 -0
{dreamer4-0.0.99 → dreamer4-0.1.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.99
+Version: 0.1.5
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -53,11 +53,100 @@ Description-Content-Type: text/markdown
 
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
 
@@ -72,3 +161,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
dreamer4-0.1.5/README.md ADDED

@@ -0,0 +1,112 @@
+<img src="./dreamer4-fig2.png" width="400px"></img>
+
+## Dreamer 4
+
+Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
+
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+(actor_loss + critic_loss).backward()
+```
+
+## Citation
+
+```bibtex
+@misc{hafner2025trainingagentsinsidescalable,
+    title = {Training Agents Inside of Scalable World Models},
+    author = {Danijar Hafner and Wilson Yan and Timothy Lillicrap},
+    year = {2025},
+    eprint = {2509.24527},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.AI},
+    url = {https://arxiv.org/abs/2509.24527},
+}
+```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
{dreamer4-0.0.99 → dreamer4-0.1.5}/dreamer4/dreamer4.py

@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -91,6 +91,14 @@ class Experience:
     agent_index: int = 0
     is_from_world_model: bool = True
 
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
 def combine_experiences(
     exps: list[Experiences]
 ) -> Experience:
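A self-contained sketch of the device-moving pattern the new `cpu` / `to` methods rely on, with a toy dataclass standing in for the real `Experience` (field names here are illustrative): `asdict` turns the dataclass into a nested dict and `tree_map` moves only the tensor leaves, leaving ints and other metadata untouched.

```python
import torch
from dataclasses import dataclass, asdict
from torch.utils._pytree import tree_map

@dataclass
class ToyExperience:
    latents: torch.Tensor
    rewards: torch.Tensor
    agent_index: int = 0

    def to(self, device):
        # move every tensor leaf, leave non-tensors as they are
        moved = tree_map(lambda t: t.to(device) if torch.is_tensor(t) else t, asdict(self))
        return ToyExperience(**moved)

    def cpu(self):
        return self.to(torch.device('cpu'))

exp = ToyExperience(latents = torch.randn(2, 10, 32), rewards = torch.randn(2, 10))
exp = exp.cpu()  # a no-op on CPU tensors, but would bring CUDA tensors back to host
```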
@@ -1179,10 +1187,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
@@ -1331,6 +1340,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1353,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
@@ -1493,7 +1502,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,14 +1515,12 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
-
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
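A small sketch of the position bookkeeping the fixed line implies, with made-up shapes: when a time KV cache is present only the newest timestep is run, so its rotary position starts after however many timesteps are already cached.

```python
import torch

past_tokens = torch.randn(2, 7, 512)      # (batch, cached timesteps, dim) - illustrative shapes
rotary_seq_len = 1                        # only the newest timestep is processed
rotary_pos_offset = past_tokens.shape[1]  # 7 timesteps already sit in the cache

positions = rotary_pos_offset + torch.arange(rotary_seq_len)
print(positions)  # tensor([7])
```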
@@ -1687,6 +1695,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -1847,7 +1856,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
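For context on the `*losses` unpacking above: a minimal sketch assuming `TokenizerLosses` is a two-field namedtuple-like container (the real class lives in dreamer4.py; the field names below are guesses based on `recon_loss` and `lpips_loss`).

```python
from collections import namedtuple

# illustrative stand-in for the repo's TokenizerLosses container
TokenizerLosses = namedtuple('TokenizerLosses', ['recon', 'lpips'])

losses = (0.5, 0.1)

# TokenizerLosses(losses) would pass the whole tuple as the first field and leave the second unfilled;
# star-unpacking spreads the two losses into their named slots
named = TokenizerLosses(*losses)
print(named.recon, named.lpips)  # 0.5 0.1
```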
@@ -1900,7 +1909,9 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
-
+        pmpo_reverse_kl = True,
+        pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2102,12 +2113,13 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
         self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+        self.pmpo_reverse_kl = pmpo_reverse_kl
 
         # rewards related
 
@@ -2124,7 +2136,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2355,6 +2367,9 @@ class DynamicsWorldModel(Module):
         elif len(env_step_out) == 4:
             next_frame, reward, terminated, truncated = env_step_out
 
+        elif len(env_step_out) == 5:
+            next_frame, reward, terminated, truncated, info = env_step_out
+
         # update episode lens
 
         episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
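The new branch looks like it accommodates Gymnasium-style environments, whose `step` returns the five-tuple `(obs, reward, terminated, truncated, info)`. A hypothetical dispatch on tuple length (not the repo's exact handling of the shorter forms):

```python
def unpack_step(env_step_out):
    # accept 3-, 4- or 5-tuples from env.step(action)
    if len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
        truncated, info = False, {}
    elif len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
        info = {}
    elif len(env_step_out) == 5:
        next_frame, reward, terminated, truncated, info = env_step_out
    else:
        raise ValueError(f'unexpected env.step output of length {len(env_step_out)}')

    return next_frame, reward, terminated, truncated, info

print(unpack_step(('frame', 1.0, False, False, {})))  # gymnasium-style 5-tuple
```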
@@ -2423,8 +2438,12 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
         use_pmpo = True,
+        normalize_advantages = None,
         eps = 1e-6
     ):
+        assert isinstance(experience, Experience)
+
+        experience = experience.to(self.device)
 
         latents = experience.latents
         actions = experience.actions
@@ -2437,7 +2456,7 @@ class DynamicsWorldModel(Module):
         step_size = experience.step_size
         agent_index = experience.agent_index
 
-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
 
         batch, time = latents.shape[0], latents.shape[1]
 
@@ -2451,8 +2470,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
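The `~` matters here: if `lens_to_mask` follows the common convention of returning `True` for timesteps inside the episode, zeroing must target the complement, i.e. the padded tail. A toy check with a hypothetical `lens_to_mask` (not the repo's helper):

```python
import torch

def lens_to_mask(lens, max_len):
    # True where a timestep is within the episode length (a common convention, assumed here)
    return torch.arange(max_len)[None, :] < lens[:, None]

lens = torch.tensor([2, 4])
rewards = torch.ones(2, 4)

mask = lens_to_mask(lens, 4)              # [[T, T, F, F], [T, T, T, T]]
rewards = rewards.masked_fill(~mask, 0.)  # zero out only the padded timesteps
print(rewards)                            # tensor([[1., 1., 0., 0.], [1., 1., 1., 1.]])
```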
@@ -2487,7 +2506,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var =
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
@@ -2505,16 +2524,19 @@ class DynamicsWorldModel(Module):
         else:
             advantage = returns - old_values
 
-        #
+        # if using pmpo, do not normalize advantages, but can be overridden
+
+        normalize_advantages = default(normalize_advantages, not use_pmpo)
+
+        if normalize_advantages:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
+
         # https://arxiv.org/abs/2410.04166v1
 
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
 
-        else:
-            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
-
         # replay for the action logits and values
         # but only do so if fine tuning the entire world model for RL
 
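For reference, `F.layer_norm` over the tensor's full shape is just a compact way to standardize the advantages to zero mean and unit variance; PMPO, per the masks in the surrounding code, only consumes the sign of the advantage, which is presumably why normalization now defaults off in that path. A quick equivalence check:

```python
import torch
import torch.nn.functional as F

advantage = torch.randn(2, 10)
eps = 1e-6

normed = F.layer_norm(advantage, advantage.shape, eps = eps)

# the same whitening written out explicitly (biased variance, matching layer_norm)
manual = (advantage - advantage.mean()) / torch.sqrt(advantage.var(unbiased = False) + eps)

assert torch.allclose(normed, manual, atol = 1e-5)
```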
@@ -2578,11 +2600,18 @@ class DynamicsWorldModel(Module):
         # take care of kl
 
         if self.pmpo_kl_div_loss_weight > 0.:
+
             new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
 
+            kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
+
             # mentioned that the "reverse direction for the prior KL" was used
+            # make optional, as observed instability in toy task
+
+            if self.pmpo_reverse_kl:
+                kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
 
-            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(
+            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)
 
             # accumulate discrete and continuous kl div
 
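A compact illustration of what swapping `kl_div_inputs` and `kl_div_targets` does for a categorical action head: it flips `KL(new ‖ old)` into the reverse direction `KL(old ‖ new)`. The distributions below are placeholders, not the repo's `action_embedder.kl_div`:

```python
import torch
from torch.distributions import Categorical, kl_divergence

old_logits = torch.randn(2, 10, 4)  # e.g. (batch, time, num_discrete_actions)
new_logits = torch.randn(2, 10, 4)

old_dist, new_dist = Categorical(logits = old_logits), Categorical(logits = new_logits)

forward_kl = kl_divergence(new_dist, old_dist)  # KL(new || old)
reverse_kl = kl_divergence(old_dist, new_dist)  # KL(old || new) - the argument swap

print(forward_kl.shape, reverse_kl.shape)       # torch.Size([2, 10]) twice
```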
@@ -2680,12 +2709,22 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
         return_time_kv_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
 
     ): # (b t n d) | (b c t h w)
 
+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
         has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
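Usage-wise this is just a convenience switch; continuing the README example above, a single `generate` call with the flag set should yield generations that `learn_from_experience` can consume directly (a sketch, reusing the `world_model` defined there):

```python
dreams = world_model.generate(
    10,
    batch_size = 2,
    return_decoded_video = True,
    return_for_policy_optimization = True  # turns on agent actions, log probs / values, per-frame rewards
)

actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
```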
@@ -2755,6 +2794,19 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
+        # determine whether to take an extra step if
+        # (1) using time kv cache
+        # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
         noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
@@ -2762,7 +2814,10 @@ class DynamicsWorldModel(Module):
         if has_proprio:
             noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
 
-        for step in range(num_steps):
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):
+
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2805,6 +2860,11 @@ class DynamicsWorldModel(Module):
             if use_time_kv_cache and is_last_step:
                 time_kv_cache = next_time_kv_cache
 
+            # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+            if take_extra_step and is_last_step:
+                break
+
             # maybe proprio
 
             if has_proprio:
@@ -3007,7 +3067,7 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_intermediates = False,
-        add_autoregressive_action_loss =
+        add_autoregressive_action_loss = True,
         update_loss_ema = None,
         latent_has_view_dim = False
     ):
@@ -3039,8 +3099,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
@@ -3464,7 +3524,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3508,7 +3568,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
 
{dreamer4-0.0.99 → dreamer4-0.1.5}/tests/test_dreamer.py

@@ -680,6 +680,12 @@ def test_online_rl(
 
     combined_experience = combine_experiences([one_experience, another_experience])
 
+    # quick test moving the experience to different devices
+
+    if torch.cuda.is_available():
+        combined_experience = combined_experience.to(torch.device('cuda'))
+        combined_experience = combined_experience.to(world_model_and_policy.device)
+
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
dreamer4-0.0.99/README.md DELETED

@@ -1,21 +0,0 @@
-<img src="./dreamer4-fig2.png" width="400px"></img>
-
-## Dreamer 4 (wip)
-
-Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
-
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
-
-## Citation
-
-```bibtex
-@misc{hafner2025trainingagentsinsidescalable,
-    title = {Training Agents Inside of Scalable World Models},
-    author = {Danijar Hafner and Wilson Yan and Timothy Lillicrap},
-    year = {2025},
-    eprint = {2509.24527},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.AI},
-    url = {https://arxiv.org/abs/2509.24527},
-}
-```