dreamer4 0.1.0__tar.gz → 0.1.2__tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.0
+Version: 0.1.2
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -60,7 +60,7 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
 ## Install
 
 ```bash
-$ pip install dreamer4-pytorch
+$ pip install dreamer4
 ```
 
 ## Usage
@@ -137,4 +137,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
 }
 ```
 
-*the conquest of nature is to be achieved through number and measure* - angels to Descartes, in a dream, the story goes.
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
@@ -7,7 +7,7 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
 ## Install
 
 ```bash
-$ pip install dreamer4-pytorch
+$ pip install dreamer4
 ```
 
 ## Usage
@@ -84,4 +84,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
 }
 ```
 
-*the conquest of nature is to be achieved through number and measure* - angels to Descartes, in a dream, the story goes.
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
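A small but consequential README fix in both copies above: the project is published on PyPI as `dreamer4` (matching `name = "dreamer4"` in the pyproject hunk below, and the Homepage URL), so the old `pip install dreamer4-pytorch` command pointed at the wrong package name.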
@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
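These two attention hunks move the rotary embedding ahead of the KV-cache concatenation. The likely rationale, assuming the standard rotary scheme where each key is rotated by its absolute position: during cached decoding `rotary_pos_emb` covers only the newest position, and cached keys were already rotated on the step that produced them, so rotating after the `cat` would rotate the cached keys a second time. A minimal sketch of why rotate-then-cache reproduces the full-sequence computation (`rope` here is a stand-in, not the repo's `apply_rotations`):

```python
# sketch under two assumptions: rotary follows the standard rotate-half
# convention, and cached keys were rotated on the step that produced them

import torch

def rope(pos, x):
    # rotate features of x (seq, dim) by absolute positions pos (seq,)
    dim = x.shape[-1]
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(pos.float(), inv_freq)
    freqs = torch.cat((freqs, freqs), dim = -1)
    x1, x2 = x.chunk(2, dim = -1)
    rotated = torch.cat((-x2, x1), dim = -1)
    return x * freqs.cos() + rotated * freqs.sin()

seq, dim = 4, 8
k = torch.randn(seq, dim)

# full-sequence pass: rotate all keys at once
k_full = rope(torch.arange(seq), k)

# cached decoding: rotate each new key *before* appending it to the cache,
# exactly as the hunk above now does
cache = [rope(torch.tensor([t]), k[t:t + 1]) for t in range(seq)]
k_decoded = torch.cat(cache, dim = 0)

assert torch.allclose(k_full, k_decoded, atol = 1e-6)

# the old ordering rotated *after* concatenating the cache, re-rotating every
# cached key by the newest position's angles
```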
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
 
         has_kv_cache = exists(kv_cache)
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
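The offset fix reads as a shape-indexing correction. Assuming tokens in this axial transformer are packed as `(batch, time, space, dim)`, consistent with the `tokens[:, :-1]` slice along dim 1, `.shape[-2]` would count spatial tokens while `.shape[1]` counts elapsed time steps, which is what the rotary position offset should be:

```python
# hypothetical shapes for illustration: (batch, time, space, dim)
import torch

past_tokens = torch.randn(2, 9, 16, 64)

assert past_tokens.shape[-2] == 16  # spatial axis - the old, wrong offset
assert past_tokens.shape[1] == 9    # elapsed time steps - the corrected offset
```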
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
            time_block_every = time_block_every,
            num_special_spatial_tokens = num_latent_tokens,
            num_residual_streams = num_residual_streams,
+           special_attend_only_itself = True,
            final_norm = True
        )
 
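One behavioral change to the tokenizer's transformer: the flag's semantics aren't visible in this diff, but judging by its name, `special_attend_only_itself = True` presumably restricts each special (latent) token to attending only to itself among the special tokens.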
@@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
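Both of these `masked_fill` hunks (this one and the reward-loss one further down) fix the same polarity bug. Assuming `lens_to_mask` follows the usual convention of marking valid timesteps `True`, and since `masked_fill` writes where its mask is `True`, the padded tail must be zeroed with the negated mask; the old code zeroed the valid entries instead. A minimal sketch with a toy `lens_to_mask` (not the repo's implementation):

```python
import torch

def lens_to_mask(lens, total_time):
    # True for timesteps inside each sequence's length
    return torch.arange(total_time) < lens[:, None]

lens = torch.tensor([2, 4])
rewards = torch.ones(2, 4)

mask = lens_to_mask(lens, 4)  # [[T, T, F, F], [T, T, T, T]]

zeroed = rewards.masked_fill(~mask, 0.)  # keeps valid rewards, zeroes padding
assert zeroed.sum() == 6                 # 2 + 4 valid steps survive

# the old code filled where the mask was True, zeroing the valid rewards
# and keeping the padding instead
```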
@@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
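A hedged reading, since `returns_for_stats` is defined outside this hunk: it presumably gathers only the valid (unpadded) returns, so the mean and variance feeding the EMA below aren't dragged toward zero by the padded timesteps that were just masked out.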
@@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
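Same `masked_fill` polarity fix as the GAE hunk above: with `reward_loss_mask` presumably `True` where a reward target is valid, the losses to drop live at `~reward_loss_mask`.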
@@ -3554,7 +3554,7 @@ class DynamicsWorldModel(Module):
         discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
 
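A copy-paste fix: the multi-token-prediction targets for the continuous branch were being built from `discrete_actions`; they are now built from `continuous_actions`, mirroring the discrete branch just above.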
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.0"
+version = "0.1.2"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
All remaining files in the package are unchanged between the two versions.