dreamer4 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreamer4-0.1.0 → dreamer4-0.1.2}/PKG-INFO +3 -3
- {dreamer4-0.1.0 → dreamer4-0.1.2}/README.md +2 -2
- {dreamer4-0.1.0 → dreamer4-0.1.2}/dreamer4/dreamer4.py +13 -13
- {dreamer4-0.1.0 → dreamer4-0.1.2}/pyproject.toml +1 -1
- {dreamer4-0.1.0 → dreamer4-0.1.2}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/.github/workflows/test.yml +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/.gitignore +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/LICENSE +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/dreamer4/__init__.py +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/dreamer4/mocks.py +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/dreamer4/trainers.py +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/dreamer4-fig2.png +0 -0
- {dreamer4-0.1.0 → dreamer4-0.1.2}/tests/test_dreamer.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dreamer4
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Dreamer 4
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/dreamer4/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/dreamer4
|
|
@@ -60,7 +60,7 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
|
|
|
60
60
|
## Install
|
|
61
61
|
|
|
62
62
|
```bash
|
|
63
|
-
$ pip install dreamer4
|
|
63
|
+
$ pip install dreamer4
|
|
64
64
|
```
|
|
65
65
|
|
|
66
66
|
## Usage
|
|
@@ -137,4 +137,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
|
|
|
137
137
|
}
|
|
138
138
|
```
|
|
139
139
|
|
|
140
|
-
*the conquest of nature is to be achieved through number and measure
|
|
140
|
+
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
|
|
@@ -7,7 +7,7 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
|
|
|
7
7
|
## Install
|
|
8
8
|
|
|
9
9
|
```bash
|
|
10
|
-
$ pip install dreamer4
|
|
10
|
+
$ pip install dreamer4
|
|
11
11
|
```
|
|
12
12
|
|
|
13
13
|
## Usage
|
|
@@ -84,4 +84,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
|
|
|
84
84
|
}
|
|
85
85
|
```
|
|
86
86
|
|
|
87
|
-
*the conquest of nature is to be achieved through number and measure
|
|
87
|
+
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
|
|
@@ -1331,6 +1331,12 @@ class Attention(Module):
|
|
|
1331
1331
|
q = self.q_heads_rmsnorm(q)
|
|
1332
1332
|
k = self.k_heads_rmsnorm(k)
|
|
1333
1333
|
|
|
1334
|
+
# rotary
|
|
1335
|
+
|
|
1336
|
+
if exists(rotary_pos_emb):
|
|
1337
|
+
q = apply_rotations(rotary_pos_emb, q)
|
|
1338
|
+
k = apply_rotations(rotary_pos_emb, k)
|
|
1339
|
+
|
|
1334
1340
|
# caching
|
|
1335
1341
|
|
|
1336
1342
|
if exists(kv_cache):
|
|
@@ -1338,12 +1344,6 @@ class Attention(Module):
|
|
|
1338
1344
|
k = cat((ck, k), dim = -2)
|
|
1339
1345
|
v = cat((cv, v), dim = -2)
|
|
1340
1346
|
|
|
1341
|
-
# rotary
|
|
1342
|
-
|
|
1343
|
-
if exists(rotary_pos_emb):
|
|
1344
|
-
q = apply_rotations(rotary_pos_emb, q)
|
|
1345
|
-
k = apply_rotations(rotary_pos_emb, k)
|
|
1346
|
-
|
|
1347
1347
|
# attention
|
|
1348
1348
|
|
|
1349
1349
|
attend_fn = default(attend_fn, naive_attend)
|
|
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
|
|
|
1507
1507
|
|
|
1508
1508
|
has_kv_cache = exists(kv_cache)
|
|
1509
1509
|
|
|
1510
|
-
|
|
1511
1510
|
if has_kv_cache:
|
|
1512
1511
|
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
|
1513
1512
|
|
|
1514
1513
|
rotary_seq_len = 1
|
|
1515
|
-
rotary_pos_offset = past_tokens.shape[
|
|
1514
|
+
rotary_pos_offset = past_tokens.shape[1]
|
|
1516
1515
|
else:
|
|
1517
1516
|
rotary_seq_len = time
|
|
1518
1517
|
rotary_pos_offset = 0
|
|
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
|
|
|
1687
1686
|
time_block_every = time_block_every,
|
|
1688
1687
|
num_special_spatial_tokens = num_latent_tokens,
|
|
1689
1688
|
num_residual_streams = num_residual_streams,
|
|
1689
|
+
special_attend_only_itself = True,
|
|
1690
1690
|
final_norm = True
|
|
1691
1691
|
)
|
|
1692
1692
|
|
|
@@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module):
|
|
|
2456
2456
|
if exists(experience.lens):
|
|
2457
2457
|
mask_for_gae = lens_to_mask(experience.lens, time)
|
|
2458
2458
|
|
|
2459
|
-
rewards = rewards.masked_fill(mask_for_gae, 0.)
|
|
2460
|
-
old_values = old_values.masked_fill(mask_for_gae, 0.)
|
|
2459
|
+
rewards = rewards.masked_fill(~mask_for_gae, 0.)
|
|
2460
|
+
old_values = old_values.masked_fill(~mask_for_gae, 0.)
|
|
2461
2461
|
|
|
2462
2462
|
# calculate returns
|
|
2463
2463
|
|
|
@@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module):
|
|
|
2492
2492
|
|
|
2493
2493
|
# mean, var - todo - handle distributed
|
|
2494
2494
|
|
|
2495
|
-
returns_mean, returns_var =
|
|
2495
|
+
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
|
|
2496
2496
|
|
|
2497
2497
|
# ema
|
|
2498
2498
|
|
|
@@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module):
|
|
|
3510
3510
|
|
|
3511
3511
|
reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
|
|
3512
3512
|
|
|
3513
|
-
reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
|
|
3513
|
+
reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
|
|
3514
3514
|
|
|
3515
3515
|
if is_var_len:
|
|
3516
3516
|
reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
|
|
@@ -3554,7 +3554,7 @@ class DynamicsWorldModel(Module):
|
|
|
3554
3554
|
discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
|
|
3555
3555
|
|
|
3556
3556
|
if exists(continuous_actions):
|
|
3557
|
-
continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(
|
|
3557
|
+
continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
|
|
3558
3558
|
continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
|
|
3559
3559
|
continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
|
|
3560
3560
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|