dreamer4 0.0.102__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
  from torch.distributions import Normal, kl
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
- from torch.utils._pytree import tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import torchvision
  from torchvision.models import VGG16_Weights
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp

  from hyper_connections import get_init_and_expand_reduce_stream_functions

+ from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
  from assoc_scan import AssocScan

  # ein related
@@ -68,10 +70,14 @@ except ImportError:

  LinearNoBias = partial(Linear, bias = False)

- TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

  WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

+ AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
  MaybeTensor = Tensor | None

  @dataclass
@@ -91,6 +97,14 @@ class Experience:
  agent_index: int = 0
  is_from_world_model: bool = True

+ def cpu(self):
+ return self.to(torch.device('cpu'))
+
+ def to(self, device):
+ experience_dict = asdict(self)
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+ return Experience(**experience_dict)
+
  def combine_experiences(
  exps: list[Experiences]
  ) -> Experience:
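The `Experience.cpu()` / `Experience.to()` methods added above round-trip the dataclass through `asdict` and `tree_map`, moving only tensor leaves and passing everything else through. A minimal standalone sketch of the same pattern, using a made-up `Sample` dataclass rather than the package's `Experience`:

```python
# sketch only: a hypothetical dataclass, not dreamer4's Experience
from dataclasses import dataclass, asdict

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

@dataclass
class Sample:
    frames: torch.Tensor
    reward: torch.Tensor
    note: str = 'not a tensor, passed through unchanged'

    def to(self, device):
        # move every tensor leaf to the target device, then rebuild the dataclass
        moved = tree_map(lambda t: t.to(device) if is_tensor(t) else t, asdict(self))
        return Sample(**moved)

    def cpu(self):
        return self.to(torch.device('cpu'))

sample = Sample(frames = torch.randn(2, 3), reward = torch.tensor(1.))
sample = sample.cpu()  # tensors now on cpu, the string field is untouched
```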
@@ -1179,10 +1193,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F

  def block_mask_special_tokens_right(
  seq_len,
- num_tokens
+ num_tokens,
+ special_attend_only_itself = False
  ):
  def inner(b, h, q, k):
- return special_token_mask(q, k, seq_len, num_tokens)
+ return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
  return inner

  def compose_mask(mask1, mask2):
@@ -1312,7 +1327,7 @@ class Attention(Module):
  self,
  tokens, # (b n d)
  kv_cache = None,
- return_kv_cache = False,
+ return_intermediates = False,
  rotary_pos_emb = None,
  attend_fn: Callable | None = None
  ):
@@ -1331,6 +1346,12 @@ class Attention(Module):
  q = self.q_heads_rmsnorm(q)
  k = self.k_heads_rmsnorm(k)

+ # rotary
+
+ if exists(rotary_pos_emb):
+ q = apply_rotations(rotary_pos_emb, q)
+ k = apply_rotations(rotary_pos_emb, k)
+
  # caching

  if exists(kv_cache):
@@ -1338,12 +1359,6 @@ class Attention(Module):
  k = cat((ck, k), dim = -2)
  v = cat((cv, v), dim = -2)

- # rotary
-
- if exists(rotary_pos_emb):
- q = apply_rotations(rotary_pos_emb, q)
- k = apply_rotations(rotary_pos_emb, k)
-
  # attention

  attend_fn = default(attend_fn, naive_attend)
@@ -1366,10 +1381,10 @@ class Attention(Module):

  out = inverse_packed_batch(out)

- if not return_kv_cache:
+ if not return_intermediates:
  return out

- return out, stack((k, v))
+ return out, AttentionIntermediates(stack((k, v)), tokens)

  # feedforward

@@ -1483,7 +1498,7 @@ class AxialSpaceTimeTransformer(Module):
  self,
  tokens, # (b t s d)
  kv_cache: Tensor | None = None, # (y 2 b h t d)
- return_kv_cache = False
+ return_intermediates = False

  ): # (b t s d) | (y 2 b h t d)

@@ -1493,7 +1508,8 @@ class AxialSpaceTimeTransformer(Module):

  # attend functions for space and time

- use_flex = exists(flex_attention) and tokens.is_cuda
+ has_kv_cache = exists(kv_cache)
+ use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix

  attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

@@ -1505,14 +1521,12 @@ class AxialSpaceTimeTransformer(Module):

  time_attn_kv_caches = []

- has_kv_cache = exists(kv_cache)
-

  if has_kv_cache:
  past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

  rotary_seq_len = 1
- rotary_pos_offset = past_tokens.shape[-2]
+ rotary_pos_offset = past_tokens.shape[1]
  else:
  rotary_seq_len = time
  rotary_pos_offset = 0
@@ -1525,6 +1539,11 @@ class AxialSpaceTimeTransformer(Module):

  rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)

+ # normed attention inputs
+
+ normed_time_attn_inputs = []
+ normed_space_attn_inputs = []
+
  # attention

  tokens = self.expand_streams(tokens)
@@ -1545,12 +1564,12 @@ class AxialSpaceTimeTransformer(Module):

  # attention layer

- tokens, next_kv_cache = attn(
+ tokens, attn_intermediates = attn(
  tokens,
  rotary_pos_emb = layer_rotary_pos_emb,
  attend_fn = attend_fn,
  kv_cache = maybe_kv_cache,
- return_kv_cache = True
+ return_intermediates = True
  )

  tokens = post_attn_rearrange(tokens)
@@ -1562,7 +1581,13 @@ class AxialSpaceTimeTransformer(Module):
  # save kv cache if is time layer

  if layer_is_time:
- time_attn_kv_caches.append(next_kv_cache)
+ time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+ # save time attention inputs for decorr
+
+ space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+ space_or_time_inputs.append(attn_intermediates.normed_inputs)

  tokens = self.reduce_streams(tokens)

@@ -1572,10 +1597,16 @@ class AxialSpaceTimeTransformer(Module):
  # just concat the past tokens back on for now, todo - clean up the logic
  out = cat((past_tokens, out), dim = 1)

- if not return_kv_cache:
+ if not return_intermediates:
  return out

- return out, stack(time_attn_kv_caches)
+ intermediates = TransformerIntermediates(
+ stack(time_attn_kv_caches),
+ stack(normed_time_attn_inputs),
+ stack(normed_space_attn_inputs)
+ )
+
+ return out, intermediates

  # video tokenizer

@@ -1601,12 +1632,15 @@ class VideoTokenizer(Module):
  per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
  lpips_loss_network: Module | None = None,
  lpips_loss_weight = 0.2,
+ encoder_add_decor_aux_loss = False,
+ decor_auxx_loss_weight = 0.1,
+ decorr_sample_frac = 0.25,
  nd_rotary_kwargs: dict = dict(
  rope_min_freq = 1.,
  rope_max_freq = 10000.,
  rope_p_zero_freqs = 0.
  ),
- num_residual_streams = 1
+ num_residual_streams = 1,
  ):
  super().__init__()
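The `per_image_patch_mask_prob` comment above describes drawing one masking probability per image, uniformly between 0. and 0.9, and then masking patches at that rate. A small illustrative sketch of that sampling scheme (shapes and tensor names are invented, not the package's internals):

```python
# illustrative only: per-image mask probability, then i.i.d. patch masking at that rate
import torch

batch, num_patches = 4, 64
low, high = 0., 0.9

per_image_prob = torch.empty(batch, 1).uniform_(low, high)         # one probability per image
patch_is_masked = torch.rand(batch, num_patches) < per_image_prob  # (batch, patches) boolean mask

print(patch_is_masked.float().mean(dim = -1))  # fraction masked per image ~ the drawn probabilities
```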
 
@@ -1687,6 +1721,7 @@ class VideoTokenizer(Module):
  time_block_every = time_block_every,
  num_special_spatial_tokens = num_latent_tokens,
  num_residual_streams = num_residual_streams,
+ special_attend_only_itself = True,
  final_norm = True
  )

@@ -1700,6 +1735,14 @@ class VideoTokenizer(Module):
  if self.has_lpips_loss:
  self.lpips = LPIPSLoss(lpips_loss_network)

+ # decorr aux loss
+ # https://arxiv.org/abs/2510.14657
+
+ self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+ self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+ self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
  @property
  def device(self):
  return self.zero.device
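Together with the constructor arguments added a couple of hunks above, this wires a decorrelation auxiliary loss (arXiv 2510.14657) into the tokenizer encoder. A hedged usage sketch, reusing the keyword names exactly as they appear in this diff; the sizes are illustrative (borrowed from the README example later in this diff) and `return_all_losses` is the flag visible in the loss hunk below:

```python
import torch
from dreamer4 import VideoTokenizer

# sketch: enable the new decorrelation aux loss on the tokenizer encoder
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256,
    encoder_add_decor_aux_loss = True,  # keyword names as added in this version
    decor_auxx_loss_weight = 0.1,
    decorr_sample_frac = 0.25
)

video = torch.randn(2, 3, 10, 256, 256)

# with return_all_losses, the forward also returns the TokenizerLosses namedtuple
# (recon, lpips, time_decorr, space_decorr)
total_loss, losses = tokenizer(video, return_all_losses = True)
total_loss.backward()
```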
@@ -1813,7 +1856,7 @@ class VideoTokenizer(Module):

  # encoder attention

- tokens = self.encoder_transformer(tokens)
+ tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)

  # latent bottleneck

@@ -1835,19 +1878,27 @@ class VideoTokenizer(Module):
  if self.has_lpips_loss:
  lpips_loss = self.lpips(video, recon_video)

+ time_decorr_loss = space_decorr_loss = self.zero
+
+ if self.encoder_add_decor_aux_loss:
+ time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+ space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
  # losses

  total_loss = (
  recon_loss +
- lpips_loss * self.lpips_loss_weight
+ lpips_loss * self.lpips_loss_weight +
+ time_decorr_loss * self.decorr_aux_loss_weight +
+ space_decorr_loss * self.decorr_aux_loss_weight
  )

  if not return_all_losses:
  return total_loss

- losses = (recon_loss, lpips_loss)
+ losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

- return total_loss, TokenizerLosses(losses)
+ return total_loss, TokenizerLosses(*losses)

  # dynamics model, axial space-time transformer

@@ -2104,7 +2155,7 @@ class DynamicsWorldModel(Module):

  self.ppo_eps_clip = ppo_eps_clip
  self.value_clip = value_clip
- self.policy_entropy_weight = value_clip
+ self.policy_entropy_weight = policy_entropy_weight

  # pmpo related

@@ -2127,7 +2178,7 @@ class DynamicsWorldModel(Module):
  self.flow_loss_normalizer = LossNormalizer(1)
  self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
  self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
- self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+ self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

  self.latent_flow_loss_weight = latent_flow_loss_weight

@@ -2358,6 +2409,9 @@ class DynamicsWorldModel(Module):
  elif len(env_step_out) == 4:
  next_frame, reward, terminated, truncated = env_step_out

+ elif len(env_step_out) == 5:
+ next_frame, reward, terminated, truncated, info = env_step_out
+
  # update episode lens

  episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@@ -2429,6 +2483,9 @@ class DynamicsWorldModel(Module):
  normalize_advantages = None,
  eps = 1e-6
  ):
+ assert isinstance(experience, Experience)
+
+ experience = experience.to(self.device)

  latents = experience.latents
  actions = experience.actions
@@ -2441,7 +2498,7 @@ class DynamicsWorldModel(Module):
  step_size = experience.step_size
  agent_index = experience.agent_index

- assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+ assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'

  batch, time = latents.shape[0], latents.shape[1]

@@ -2455,8 +2512,8 @@ class DynamicsWorldModel(Module):
  if exists(experience.lens):
  mask_for_gae = lens_to_mask(experience.lens, time)

- rewards = rewards.masked_fill(mask_for_gae, 0.)
- old_values = old_values.masked_fill(mask_for_gae, 0.)
+ rewards = rewards.masked_fill(~mask_for_gae, 0.)
+ old_values = old_values.masked_fill(~mask_for_gae, 0.)

  # calculate returns

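The polarity flip above assumes the usual `lens_to_mask` convention of True on valid timesteps, so the padded tail is the part selected by `~mask_for_gae`. A small worked example of the corrected masking (values made up):

```python
# worked example: zero out rewards past each sequence's length
import torch

lens = torch.tensor([2, 4])
time = 4

# True on valid timesteps, False on the padded tail (assumed lens_to_mask convention)
mask = torch.arange(time)[None, :] < lens[:, None]

rewards = torch.ones(2, time)
rewards = rewards.masked_fill(~mask, 0.)  # keep valid steps, zero the padding

print(rewards)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])
```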
@@ -2491,7 +2548,7 @@ class DynamicsWorldModel(Module):

  # mean, var - todo - handle distributed

- returns_mean, returns_var = returns.mean(), returns.var()
+ returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

  # ema

@@ -2694,12 +2751,22 @@ class DynamicsWorldModel(Module):
  return_rewards_per_frame = False,
  return_agent_actions = False,
  return_log_probs_and_values = False,
+ return_for_policy_optimization = False,
  return_time_kv_cache = False,
  store_agent_embed = True,
  store_old_action_unembeds = True

  ): # (b t n d) | (b c t h w)

+ # handy flag for returning generations for rl
+
+ if return_for_policy_optimization:
+ return_agent_actions |= True
+ return_log_probs_and_values |= True
+ return_rewards_per_frame |= True
+
+ # more variables
+
  has_proprio = self.has_proprio
  was_training = self.training
  self.eval()
@@ -2769,6 +2836,19 @@ class DynamicsWorldModel(Module):
  curr_time_steps = latents.shape[1]

+ # determine whether to take an extra step if
+ # (1) using time kv cache
+ # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+ take_extra_step = (
+ use_time_kv_cache or
+ return_rewards_per_frame or
+ store_agent_embed or
+ return_agent_actions
+ )
+
+ # prepare noised latent / proprio inputs
+
  noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

  noised_proprio = None
@@ -2776,7 +2856,10 @@ class DynamicsWorldModel(Module):
  if has_proprio:
  noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)

- for step in range(num_steps):
+ # denoising steps
+
+ for step in range(num_steps + int(take_extra_step)):
+
  is_last_step = (step + 1) == num_steps

  signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2902,11 @@ class DynamicsWorldModel(Module):
  if use_time_kv_cache and is_last_step:
  time_kv_cache = next_time_kv_cache

+ # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+ if take_extra_step and is_last_step:
+ break
+
  # maybe proprio

  if has_proprio:
@@ -3021,7 +3109,7 @@ class DynamicsWorldModel(Module):
  latent_is_noised = False,
  return_all_losses = False,
  return_intermediates = False,
- add_autoregressive_action_loss = False,
+ add_autoregressive_action_loss = True,
  update_loss_ema = None,
  latent_has_view_dim = False
  ):
@@ -3053,8 +3141,8 @@ class DynamicsWorldModel(Module):
  if latents.ndim == 4:
  latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

- assert latents.shape[-2:] == self.latent_shape
- assert latents.shape[2] == self.num_video_views
+ assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+ assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

  # variables

@@ -3289,7 +3377,7 @@ class DynamicsWorldModel(Module):

  # attention

- tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
+ tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)

  # unpack

@@ -3478,7 +3566,7 @@ class DynamicsWorldModel(Module):

  reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')

- reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+ reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)

  if is_var_len:
  reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3522,7 +3610,7 @@ class DynamicsWorldModel(Module):
  discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')

  if exists(continuous_actions):
- continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+ continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
  continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
  continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')

dreamer4/trainers.py CHANGED
@@ -528,7 +528,7 @@ class SimTrainer(Module):

  total_experience += num_experience

- experiences.append(experience)
+ experiences.append(experience.cpu())

  combined_experiences = combine_experiences(experiences)

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.102
+ Version: 0.1.10
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
  Requires-Dist: hyper-connections>=0.2.1
  Requires-Dist: torch>=2.4
  Requires-Dist: torchvision
+ Requires-Dist: vit-pytorch>=1.15.3
  Requires-Dist: x-mlps-pytorch>=0.0.29
  Provides-Extra: examples
  Provides-Extra: test
@@ -53,11 +54,100 @@ Description-Content-Type: text/markdown

  <img src="./dreamer4-fig2.png" width="400px"></img>

- ## Dreamer 4 (wip)
+ ## Dreamer 4

  Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

- [Temporary Discord](https://discord.gg/MkACrrkrYR)
+ [Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+
+ ## Appreciation
+
+ - [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+ ## Install
+
+ ```bash
+ $ pip install dreamer4
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+ # video tokenizer, learned through MAE + lpips
+
+ tokenizer = VideoTokenizer(
+ dim = 512,
+ dim_latent = 32,
+ patch_size = 32,
+ image_height = 256,
+ image_width = 256
+ )
+
+ video = torch.randn(2, 3, 10, 256, 256)
+
+ # learn the tokenizer
+
+ loss = tokenizer(video)
+ loss.backward()
+
+ # dynamics world model
+
+ world_model = DynamicsWorldModel(
+ dim = 512,
+ dim_latent = 32,
+ video_tokenizer = tokenizer,
+ num_discrete_actions = 4,
+ num_residual_streams = 1
+ )
+
+ # state, action, rewards
+
+ video = torch.randn(2, 3, 10, 256, 256)
+ discrete_actions = torch.randint(0, 4, (2, 10, 1))
+ rewards = torch.randn(2, 10)
+
+ # learn dynamics / behavior cloned model
+
+ loss = world_model(
+ video = video,
+ rewards = rewards,
+ discrete_actions = discrete_actions
+ )
+
+ loss.backward()
+
+ # do the above with much data
+
+ # then generate dreams
+
+ dreams = world_model.generate(
+ 10,
+ batch_size = 2,
+ return_decoded_video = True,
+ return_for_policy_optimization = True
+ )
+
+ # learn from the dreams
+
+ actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+ (actor_loss + critic_loss).backward()
+
+ # learn from environment
+
+ from dreamer4.mocks import MockEnv
+
+ mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+ experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+ actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+ (actor_loss + critic_loss).backward()
+ ```

  ## Citation

@@ -72,3 +162,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
  url = {https://arxiv.org/abs/2509.24527},
  }
  ```
+
+ *the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
@@ -0,0 +1,8 @@
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+ dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+ dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+ dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
+ dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ dreamer4-0.1.10.dist-info/RECORD,,
@@ -1,8 +0,0 @@
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
- dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
- dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
- dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
- dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- dreamer4-0.0.102.dist-info/RECORD,,