dreamer4 0.1.2__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreamer4-0.1.2 → dreamer4-0.1.5}/PKG-INFO +30 -5
- {dreamer4-0.1.2 → dreamer4-0.1.5}/README.md +29 -4
- {dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4/dreamer4.py +24 -10
- {dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4/trainers.py +1 -1
- {dreamer4-0.1.2 → dreamer4-0.1.5}/pyproject.toml +1 -1
- {dreamer4-0.1.2 → dreamer4-0.1.5}/tests/test_dreamer.py +6 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/.github/workflows/test.yml +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/.gitignore +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/LICENSE +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4/__init__.py +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4/mocks.py +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4-fig2.png +0 -0
{dreamer4-0.1.2 → dreamer4-0.1.5}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.2
+Version: 0.1.5
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,6 +57,12 @@ Description-Content-Type: text/markdown
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
````
{dreamer4-0.1.2 → dreamer4-0.1.5}/README.md

````diff
@@ -4,6 +4,12 @@
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -26,9 +32,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -44,7 +57,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -56,7 +69,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -65,7 +78,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
````
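Pieced together, the README additions introduce an online-RL path alongside dreaming. The sketch below assembles it from the added and context lines above; the import path for `VideoTokenizer`/`DynamicsWorldModel` and any constructor arguments not visible in the diff are assumptions.

```python
import torch
from dreamer4.mocks import MockEnv
# import path assumed; the README's import lines are not part of this diff
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

tokenizer = VideoTokenizer(
    image_height = 256,  # assumed; only image_width appears in the diff context
    image_width = 256
)

# new in 0.1.5: pretrain the tokenizer directly on raw video
video = torch.randn(2, 3, 10, 256, 256)
loss = tokenizer(video)
loss.backward()

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer
)

# new in 0.1.5: collect real experience from a (vectorized) environment
# and reuse the same actor / critic update as for dreams
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)

(actor_loss + critic_loss).backward()
```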
{dreamer4-0.1.2 → dreamer4-0.1.5}/dreamer4/dreamer4.py

````diff
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -91,6 +91,14 @@ class Experience:
     agent_index: int = 0
     is_from_world_model: bool = True
 
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
 def combine_experiences(
     exps: list[Experiences]
 ) -> Experience:
````
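The two methods added to `Experience` give collected rollouts a tensor-like device API: `asdict` flattens the dataclass, `tree_map` moves every tensor leaf (nested containers included) while leaving non-tensor fields alone, and a fresh `Experience` is rebuilt. A small usage sketch, assuming `experience` is an `Experience` instance such as the one returned by `interact_with_env` in the README example:

```python
import torch

# .to(...) builds and returns a new Experience; the original is left untouched
if torch.cuda.is_available():
    experience = experience.to(torch.device('cuda'))

# ... run the actor / critic update on GPU ...

# back to CPU, e.g. before stashing the rollout in a replay buffer
experience = experience.cpu()
```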
````diff
@@ -1179,10 +1187,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
````
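For context, `block_mask_special_tokens_right` returns a callable in the `(b, h, q_idx, kv_idx)` "mask_mod" form that PyTorch's flex attention expects; the change simply threads `special_attend_only_itself` through to `special_token_mask`. A hedged sketch of how such a factory is typically consumed — `create_block_mask` and `flex_attention` are standard PyTorch (2.5+) APIs rather than code from this package, and whether `seq_len` already counts the special tokens is an assumption here:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

from dreamer4.dreamer4 import block_mask_special_tokens_right  # module-level helper shown above

# assumption: seq_len is the full sequence length, with num_tokens special tokens packed on the right
seq_len, num_tokens = 64, 4

# mask_mod callable in flex attention's (b, h, q_idx, kv_idx) form
mask_mod = block_mask_special_tokens_right(seq_len, num_tokens, special_attend_only_itself = True)

# compile the sparse block mask once, then reuse it across flex_attention calls
block_mask = create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

q = k = v = torch.randn(1, 8, seq_len, 64, device = 'cuda')
out = flex_attention(q, k, v, block_mask = block_mask)
```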
````diff
@@ -1493,7 +1502,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,7 +1515,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1847,7 +1856,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
````
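The one-character `TokenizerLosses(*losses)` change matters because the losses are built up as a plain tuple first: passing `losses` as a single argument either raises a `TypeError` or stuffs the whole tuple into the first field, depending on the field count, while `*losses` spreads it across the fields. A standalone illustration with a hypothetical two-field namedtuple standing in for the real class:

```python
from collections import namedtuple

# hypothetical stand-in for the real TokenizerLosses defined in dreamer4/dreamer4.py
TokenizerLosses = namedtuple('TokenizerLosses', ['recon', 'lpips'])

losses = (0.12, 0.03)

TokenizerLosses(*losses)   # TokenizerLosses(recon=0.12, lpips=0.03) -- the 0.1.5 behavior
# TokenizerLosses(losses)  # raises TypeError: missing 1 required positional argument: 'lpips'
```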
````diff
@@ -2104,7 +2113,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
@@ -2127,7 +2136,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2358,6 +2367,9 @@ class DynamicsWorldModel(Module):
             elif len(env_step_out) == 4:
                 next_frame, reward, terminated, truncated = env_step_out
 
+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out
+
             # update episode lens
 
             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
````
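The new `len(env_step_out) == 5` branch accepts the modern Gymnasium step convention, `(obs, reward, terminated, truncated, info)`, alongside the shorter tuples already handled. A hypothetical environment returning such a 5-tuple is sketched below; the `reset` signature and the frame format are assumptions, since the diff only shows how `step`'s return value is unpacked.

```python
import torch

class FiveTupleEnv:
    # hypothetical Gymnasium-style environment; the (channels, height, width) frame format
    # expected by interact_with_env is an assumption, not something this diff specifies

    def reset(self):
        return torch.rand(3, 256, 256)

    def step(self, action):
        next_frame = torch.rand(3, 256, 256)
        reward = 0.
        terminated = False
        truncated = False
        info = {}  # the extra diagnostics dict that makes this a 5-tuple
        return next_frame, reward, terminated, truncated, info
```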
````diff
@@ -2431,6 +2443,8 @@ class DynamicsWorldModel(Module):
     ):
         assert isinstance(experience, Experience)
 
+        experience = experience.to(self.device)
+
         latents = experience.latents
         actions = experience.actions
         old_log_probs = experience.log_probs
@@ -3085,8 +3099,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
````
{dreamer4-0.1.2 → dreamer4-0.1.5}/tests/test_dreamer.py

````diff
@@ -680,6 +680,12 @@ def test_online_rl(
 
     combined_experience = combine_experiences([one_experience, another_experience])
 
+    # quick test moving the experience to different devices
+
+    if torch.cuda.is_available():
+        combined_experience = combined_experience.to(torch.device('cuda'))
+        combined_experience = combined_experience.to(world_model_and_policy.device)
+
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
````
All other files listed above (+0 -0) are unchanged between 0.1.2 and 0.1.5.