dreamer4 0.0.102__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +129 -41
- dreamer4/trainers.py +1 -1
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/METADATA +95 -3
- dreamer4-0.1.10.dist-info/RECORD +8 -0
- dreamer4-0.0.102.dist-info/RECORD +0 -8
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -68,10 +70,14 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -91,6 +97,14 @@ class Experience:
     agent_index: int = 0
     is_from_world_model: bool = True
 
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
 def combine_experiences(
     exps: list[Experiences]
 ) -> Experience:
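The two helpers above make an `Experience` portable across devices: `to` walks the dataclass with `tree_map` and moves every tensor leaf, and `cpu` is shorthand for `to(torch.device('cpu'))`. A minimal usage sketch (the `generate` call and its arguments are taken from the README further down; the round trip itself is an assumed workflow, not code from this package):

```python
# hedged sketch of the new Experience device helpers

dreams = world_model.generate(10, batch_size = 2, return_for_policy_optimization = True)

dreams = dreams.cpu()                      # stage on CPU, e.g. in a replay buffer
dreams = dreams.to(torch.device('cuda'))   # move back before learning from it
```

Note that `learn_from_experience` now moves the experience onto the model's device itself (see the `@@ -2429` hunk below), so the final transfer is optional.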
@@ -1179,10 +1193,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
@@ -1312,7 +1327,7 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-
+        return_intermediates = False,
         rotary_pos_emb = None,
         attend_fn: Callable | None = None
     ):
@@ -1331,6 +1346,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1359,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
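Moving the rotary application above the cache concatenation is a correctness fix: keys already sitting in `kv_cache` were rotated at their own positions when first computed, so only the incoming step's query and key should be rotated, at the current offset (supplied through the `offset` passed to `self.time_rotary` later in this diff). Rotating after the concat, as before, rotated the cached keys a second time. A self-contained sketch of the invariant, using a toy rotation rather than this file's `apply_rotations`:

```python
# illustrative sketch of why rotations must precede the cache concat (toy rotation)

import torch

def rotate(x, pos):
    # toy positional rotation: rotate each 2d pair of features by angle = pos
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = torch.cos(torch.tensor(float(pos))), torch.sin(torch.tensor(float(pos)))
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1).flatten(-2)

k_cached = rotate(torch.randn(1, 1, 4), pos = 0)   # cached key, rotated at its own position already

k_new = rotate(torch.randn(1, 1, 4), pos = 1)      # new key rotated at the current offset only
k = torch.cat((k_cached, k_new), dim = -2)         # correct: cached keys stay untouched

# the old order concatenated first and then rotated the whole k,
# double-rotating everything already in the cache
```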
@@ -1366,10 +1381,10 @@ class Attention(Module):
 
         out = inverse_packed_batch(out)
 
-        if not
+        if not return_intermediates:
             return out
 
-        return out, stack((k, v))
+        return out, AttentionIntermediates(stack((k, v)), tokens)
 
 # feedforward
 
@@ -1483,7 +1498,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         tokens, # (b t s d)
         kv_cache: Tensor | None = None, # (y 2 b h t d)
-
+        return_intermediates = False
 
     ): # (b t s d) | (y 2 b h t d)
 
@@ -1493,7 +1508,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,14 +1521,12 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
-
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
@@ -1525,6 +1539,11 @@ class AxialSpaceTimeTransformer(Module):
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # normed attention inputs
+
+        normed_time_attn_inputs = []
+        normed_space_attn_inputs = []
+
         # attention
 
         tokens = self.expand_streams(tokens)
@@ -1545,12 +1564,12 @@ class AxialSpaceTimeTransformer(Module):
 
             # attention layer
 
-            tokens,
+            tokens, attn_intermediates = attn(
                 tokens,
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
-
+                return_intermediates = True
             )
 
             tokens = post_attn_rearrange(tokens)
@@ -1562,7 +1581,13 @@ class AxialSpaceTimeTransformer(Module):
             # save kv cache if is time layer
 
             if layer_is_time:
-                time_attn_kv_caches.append(next_kv_cache)
+                time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+            # save time attention inputs for decorr
+
+            space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+            space_or_time_inputs.append(attn_intermediates.normed_inputs)
 
         tokens = self.reduce_streams(tokens)
 
@@ -1572,10 +1597,16 @@ class AxialSpaceTimeTransformer(Module):
             # just concat the past tokens back on for now, todo - clean up the logic
             out = cat((past_tokens, out), dim = 1)
 
-        if not
+        if not return_intermediates:
             return out
 
-
+        intermediates = TransformerIntermediates(
+            stack(time_attn_kv_caches),
+            stack(normed_time_attn_inputs),
+            stack(normed_space_attn_inputs)
+        )
+
+        return out, intermediates
 
 # video tokenizer
 
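With this hunk the transformer hands everything downstream consumers need back in one `TransformerIntermediates` namedtuple. A sketch of how a caller might unpack it, mirroring what the tokenizer's encoder and the dynamics model do below (the `transformer` instance and `tokens` here are assumed stand-ins):

```python
# hedged sketch - unpacking the new return value

out, intermediates = transformer(tokens, return_intermediates = True)

intermediates.next_kv_cache        # stacked kv caches from the time attention layers
intermediates.normed_time_inputs   # pre-attention normed inputs, time layers
intermediates.normed_space_inputs  # pre-attention normed inputs, space layers

# plain tuple unpacking also works, as in the encoder's destructuring and the
# dynamics model's `(next_time_kv_cache, *_)` pattern further down
out, (kv_cache, time_inputs, space_inputs) = transformer(tokens, return_intermediates = True)
```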
@@ -1601,12 +1632,15 @@ class VideoTokenizer(Module):
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
         lpips_loss_network: Module | None = None,
         lpips_loss_weight = 0.2,
+        encoder_add_decor_aux_loss = False,
+        decor_auxx_loss_weight = 0.1,
+        decorr_sample_frac = 0.25,
         nd_rotary_kwargs: dict = dict(
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
         ),
-        num_residual_streams = 1
+        num_residual_streams = 1,
     ):
         super().__init__()
 
@@ -1687,6 +1721,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -1700,6 +1735,14 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)
 
+        # decorr aux loss
+        # https://arxiv.org/abs/2510.14657
+
+        self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+        self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+        self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
     @property
     def device(self):
         return self.zero.device
@@ -1813,7 +1856,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens = self.encoder_transformer(tokens)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -1835,19 +1878,27 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             lpips_loss = self.lpips(video, recon_video)
 
+        time_decorr_loss = space_decorr_loss = self.zero
+
+        if self.encoder_add_decor_aux_loss:
+            time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+            space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
         # losses
 
         total_loss = (
             recon_loss +
-            lpips_loss * self.lpips_loss_weight
+            lpips_loss * self.lpips_loss_weight +
+            time_decorr_loss * self.decorr_aux_loss_weight +
+            space_decorr_loss * self.decorr_aux_loss_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        losses = (recon_loss, lpips_loss)
+        losses = (recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
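Putting the decorrelation pieces together, a tokenizer constructed with the flag enabled now returns four named losses when asked. A sketch (constructor values copied from the README example below; note the weight kwarg is spelled `decor_auxx_loss_weight` in this version's signature):

```python
# hedged sketch - decorrelation aux loss end to end

import torch
from dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256,
    encoder_add_decor_aux_loss = True,   # decorrelate the encoder's normed attention inputs
    decor_auxx_loss_weight = 0.1         # kwarg name as it appears in this version
)

video = torch.randn(2, 3, 10, 256, 256)

total_loss, (recon, lpips, time_decorr, space_decorr) = tokenizer(video, return_all_losses = True)
total_loss.backward()
```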
@@ -2104,7 +2155,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
@@ -2127,7 +2178,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2358,6 +2409,9 @@ class DynamicsWorldModel(Module):
             elif len(env_step_out) == 4:
                 next_frame, reward, terminated, truncated = env_step_out
 
+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out
+
             # update episode lens
 
             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
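The new branch accepts the five element step output of Gymnasium-style environments, whose `step` returns an `info` dict after the terminated and truncated flags. Roughly (`env` and `action` are stand-ins):

```python
# hedged sketch - the gymnasium step signature now tolerated

env_step_out = env.step(action)

if len(env_step_out) == 4:
    next_frame, reward, terminated, truncated = env_step_out
elif len(env_step_out) == 5:
    # gymnasium convention: (obs, reward, terminated, truncated, info)
    next_frame, reward, terminated, truncated, info = env_step_out
```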
@@ -2429,6 +2483,9 @@ class DynamicsWorldModel(Module):
         normalize_advantages = None,
         eps = 1e-6
     ):
+        assert isinstance(experience, Experience)
+
+        experience = experience.to(self.device)
 
         latents = experience.latents
         actions = experience.actions
@@ -2441,7 +2498,7 @@ class DynamicsWorldModel(Module):
         step_size = experience.step_size
         agent_index = experience.agent_index
 
-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
 
         batch, time = latents.shape[0], latents.shape[1]
 
@@ -2455,8 +2512,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
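The sign flip is a genuine bug fix: `lens_to_mask` follows the usual convention of `True` for valid timesteps, so filling on the un-negated mask zeroed the real rewards and values while keeping the padding. A small self-contained check of the corrected semantics (the `lens_to_mask` body below is an assumption about its convention, not this file's implementation):

```python
import torch

def lens_to_mask(lens, time):
    # assumed convention: True marks valid (unpadded) timesteps
    return torch.arange(time) < lens[:, None]

mask = lens_to_mask(torch.tensor([2, 3]), time = 4)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True, False]])

rewards = torch.ones(2, 4)
rewards = rewards.masked_fill(~mask, 0.)   # zero the padding, keep the real rewards
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 0.]])
```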
@@ -2491,7 +2548,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var =
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
@@ -2694,12 +2751,22 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
         return_time_kv_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
 
     ): # (b t n d) | (b c t h w)
 
+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
         has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
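The convenience flag simply or's on the three return options an RL consumer needs, so rollouts destined for `learn_from_experience` take one switch instead of three. As in the README usage further down:

```python
# one flag in place of three for RL rollouts

dreams = world_model.generate(
    10,
    batch_size = 2,
    return_for_policy_optimization = True   # implies agent actions, log probs / values, per-frame rewards
)

actor_loss, critic_loss = world_model.learn_from_experience(dreams)
```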
@@ -2769,6 +2836,19 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
+        # determine whether to take an extra step if
+        # (1) using time kv cache
+        # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
         noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
@@ -2776,7 +2856,10 @@ class DynamicsWorldModel(Module):
         if has_proprio:
             noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
 
-
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):
+
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2902,11 @@ class DynamicsWorldModel(Module):
             if use_time_kv_cache and is_last_step:
                 time_kv_cache = next_time_kv_cache
 
+            # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+            if take_extra_step and is_last_step:
+                break
+
             # maybe proprio
 
             if has_proprio:
@@ -3021,7 +3109,7 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_intermediates = False,
-        add_autoregressive_action_loss =
+        add_autoregressive_action_loss = True,
         update_loss_ema = None,
         latent_has_view_dim = False
     ):
@@ -3053,8 +3141,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
@@ -3289,7 +3377,7 @@ class DynamicsWorldModel(Module):
 
         # attention
 
-        tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache,
+        tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
 
         # unpack
 
@@ -3478,7 +3566,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3522,7 +3610,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
 
dreamer4/trainers.py
CHANGED

{dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.102
+Version: 0.1.10
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
 Requires-Dist: hyper-connections>=0.2.1
 Requires-Dist: torch>=2.4
 Requires-Dist: torchvision
+Requires-Dist: vit-pytorch>=1.15.3
 Requires-Dist: x-mlps-pytorch>=0.0.29
 Provides-Extra: examples
 Provides-Extra: test
@@ -53,11 +54,100 @@ Description-Content-Type: text/markdown
 
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[
+[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward()
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
 
@@ -72,3 +162,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
dreamer4-0.1.10.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
+dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.10.dist-info/RECORD,,
dreamer4-0.0.102.dist-info/RECORD
DELETED

@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
-dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.102.dist-info/RECORD,,
{dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/WHEEL
File without changes

{dreamer4-0.0.102.dist-info → dreamer4-0.1.10.dist-info}/licenses/LICENSE
File without changes