dreamer4 0.1.2__tar.gz → 0.1.4__tar.gz
- {dreamer4-0.1.2 → dreamer4-0.1.4}/PKG-INFO +30 -5
- {dreamer4-0.1.2 → dreamer4-0.1.4}/README.md +29 -4
- {dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4/dreamer4.py +13 -9
- {dreamer4-0.1.2 → dreamer4-0.1.4}/pyproject.toml +1 -1
- {dreamer4-0.1.2 → dreamer4-0.1.4}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/.github/workflows/test.yml +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/.gitignore +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/LICENSE +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4/__init__.py +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4/mocks.py +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4/trainers.py +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4-fig2.png +0 -0
- {dreamer4-0.1.2 → dreamer4-0.1.4}/tests/test_dreamer.py +0 -0
{dreamer4-0.1.2 → dreamer4-0.1.4}/PKG-INFO +30 -5

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.2
+Version: 0.1.4
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,6 +57,12 @@ Description-Content-Type: text/markdown
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
````
{dreamer4-0.1.2 → dreamer4-0.1.4}/README.md +29 -4

````diff
@@ -4,6 +4,12 @@
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -26,9 +32,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -44,7 +57,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -56,7 +69,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -65,7 +78,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
````
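Pieced together, the hunks above turn the README walkthrough into a full loop: pretrain the tokenizer, behavior-clone the world model, then alternate between learning from dreams and learning from real environment interaction. Below is a minimal sketch of that loop under stated assumptions: constructor and `generate` arguments hidden by the diff context are omitted, and the optimizer and iteration counts are illustrative rather than anything the package prescribes.

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv

# constructors follow the diffed README; kwargs outside the diff context are omitted
tokenizer = VideoTokenizer(image_width = 256)

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer
)

# illustrative optimizer, not prescribed by the package
optim = torch.optim.Adam(world_model.parameters(), lr = 3e-4)

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

for _ in range(2):
    # imagined rollouts ("dreams"), decoded back to pixels
    dreams = world_model.generate(10, batch_size = 2, return_decoded_video = True)

    actor_loss, critic_loss = world_model.learn_from_experience(dreams)
    (actor_loss + critic_loss).backward()
    optim.step()
    optim.zero_grad()

    # real interaction, here against the package's mock environment
    experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)

    actor_loss, critic_loss = world_model.learn_from_experience(experience)
    (actor_loss + critic_loss).backward()
    optim.step()
    optim.zero_grad()
```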
{dreamer4-0.1.2 → dreamer4-0.1.4}/dreamer4/dreamer4.py +13 -9

```diff
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
```
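This first code fix threads the new `special_attend_only_itself` flag through the closure handed to the attention mask machinery; in 0.1.2 `inner` silently dropped it. A toy sketch of the pattern, with a stand-in `special_token_mask` whose semantics are invented for illustration (its real body is not part of this diff):

```python
import torch

# stand-in mask: the last num_tokens positions are "special"; normal tokens attend
# causally, special tokens attend everywhere unless restricted to themselves.
# purely illustrative semantics, not the library's actual mask
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
    q_is_special = q >= (seq_len - num_tokens)
    causal = q >= k
    special = (q == k) if special_attend_only_itself else torch.ones_like(causal)
    return torch.where(q_is_special, special, causal)

def block_mask_special_tokens_right(
    seq_len,
    num_tokens,
    special_attend_only_itself = False  # the 0.1.4 addition, forwarded below
):
    # flex-attention style callback: (batch, head, q_index, kv_index) -> bool
    def inner(b, h, q, k):
        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
    return inner

mask_fn = block_mask_special_tokens_right(6, 2, special_attend_only_itself = True)
q = torch.arange(6).view(-1, 1)
k = torch.arange(6).view(1, -1)
print(mask_fn(0, 0, q, k).int())  # 6x6 boolean mask as 0/1
```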
```diff
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
```
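These two hunks hoist `has_kv_cache` above the `use_flex` decision so the flex-attention path is bypassed whenever a KV cache is present (per the inline TODO, the cached shapes break flex attention). A minimal sketch of that dispatch, using a hypothetical `attend` helper rather than the library's actual code:

```python
import torch
import torch.nn.functional as F

try:
    from torch.nn.attention.flex_attention import flex_attention  # PyTorch >= 2.5
except ImportError:
    flex_attention = None

def exists(v):
    return v is not None

def attend(q, k, v, kv_cache = None):
    # same guard as the diff: never take the flex path when decoding with a cache
    has_kv_cache = exists(kv_cache)
    use_flex = exists(flex_attention) and q.is_cuda and not has_kv_cache

    if has_kv_cache:
        cached_k, cached_v = kv_cache
        k = torch.cat((cached_k, k), dim = -2)
        v = torch.cat((cached_v, v), dim = -2)

    if use_flex:
        return flex_attention(q, k, v)

    # fallback path, also taken on CPU
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 2, 4, 8)
cache = (torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8))
print(attend(q, k, v, kv_cache = cache).shape)  # torch.Size([1, 2, 4, 8])
```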
```diff
@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
```
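The `TokenizerLosses(*losses)` change is a classic namedtuple fix: passing the tuple itself fills only the first positional field. A quick illustration; the field names are hypothetical, since the actual definition is not shown in this diff:

```python
from collections import namedtuple

# hypothetical field names; the real definition is not part of this diff
TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

losses = (0.5, 0.1)

# 0.1.2 behavior: the whole tuple lands in the first field,
# so the second required field is missing
try:
    TokenizerLosses(losses)
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'lpips'

# 0.1.4 behavior: the tuple is unpacked into the named fields
print(TokenizerLosses(*losses))  # TokenizerLosses(recon=0.5, lpips=0.1)
```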
```diff
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
```
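This is one of the "unpassed arguments" fixes credited in the README: the `policy_entropy_weight` constructor argument now actually lands on the instance instead of being accepted and dropped. A compact illustration of the bug class, with hypothetical names:

```python
class Agent:
    def __init__(self, policy_entropy_weight = 1e-2):
        # the 0.1.4 fix in miniature: the argument is actually stored
        self.policy_entropy_weight = policy_entropy_weight

agent = Agent(policy_entropy_weight = 0.05)
print(agent.policy_entropy_weight)  # 0.05, rather than an AttributeError downstream
```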
```diff
@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
```
```diff
@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
         elif len(env_step_out) == 4:
             next_frame, reward, terminated, truncated = env_step_out
 
+        elif len(env_step_out) == 5:
+            next_frame, reward, terminated, truncated, info = env_step_out
+
         # update episode lens
 
         episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
```
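The new `len(env_step_out) == 5` branch accepts the Gymnasium-style step return `(obs, reward, terminated, truncated, info)` alongside the shorter conventions. A sketch of the kind of normalization this enables; the helper name and the 3-tuple branch are assumptions, not the package's code:

```python
def unpack_env_step(env_step_out):
    # normalize common env.step(...) return shapes; hypothetical helper
    truncated, info = False, None

    if len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
    elif len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
    elif len(env_step_out) == 5:
        # gymnasium convention: (obs, reward, terminated, truncated, info)
        next_frame, reward, terminated, truncated, info = env_step_out
    else:
        raise ValueError(f'unexpected env step output of length {len(env_step_out)}')

    return next_frame, reward, terminated, truncated, info

print(unpack_env_step(('frame', 1.0, False, False, {})))
```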
```diff
@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
```