dreamer4 0.1.4__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +142 -25
- dreamer4/trainers.py +1 -1
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/METADATA +4 -3
- dreamer4-0.1.15.dist-info/RECORD +8 -0
- dreamer4-0.1.4.dist-info/RECORD +0 -8
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/WHEEL +0 -0
- {dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py CHANGED

@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import torchvision
 from torchvision.models import VGG16_Weights
@@ -27,6 +27,8 @@ from x_mlps_pytorch.normed_mlp import create_mlp
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+from vit_pytorch.vit_with_decorr import DecorrelationLoss
+
 from assoc_scan import AssocScan
 
 # ein related
@@ -68,10 +70,14 @@ except ImportError:
 
 LinearNoBias = partial(Linear, bias = False)
 
-TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
+TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
+
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -91,6 +97,14 @@ class Experience:
     agent_index: int = 0
     is_from_world_model: bool = True
 
+    def cpu(self):
+        return self.to(torch.device('cpu'))
+
+    def to(self, device):
+        experience_dict = asdict(self)
+        experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+        return Experience(**experience_dict)
+
 def combine_experiences(
     exps: list[Experiences]
 ) -> Experience:
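The new `cpu` / `to` helpers move every tensor inside the dataclass with `torch.utils._pytree.tree_map`, leaving non-tensor fields untouched. A minimal standalone sketch of the same pattern (the `Sample` dataclass below is hypothetical, not part of dreamer4):

```python
from dataclasses import dataclass, asdict

import torch
from torch import is_tensor, Tensor
from torch.utils._pytree import tree_map

@dataclass
class Sample:
    latents: Tensor
    rewards: Tensor
    agent_index: int = 0

    def to(self, device):
        # asdict yields a plain dict pytree; tree_map visits every leaf,
        # moving tensors and passing everything else through unchanged
        sample_dict = asdict(self)
        sample_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, sample_dict)
        return Sample(**sample_dict)

sample = Sample(latents = torch.randn(2, 4), rewards = torch.randn(2))
sample = sample.to(torch.device('cpu'))  # a no-op here, but works for any device
```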
@@ -175,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):
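`maybe` turns a function into one that no-ops on `None`, which is how the transformer further down applies a per-layer rearrange to residual values that may be absent. A quick behavior sketch, assuming the codebase's usual `exists` helper:

```python
def exists(v):
    return v is not None

def maybe(fn):
    def inner(t, *args, **kwargs):
        if not exists(t) or not exists(fn):
            return None
        return fn(t)
    return inner

double = lambda x: x * 2

print(maybe(double)(5))     # 10
print(maybe(double)(None))  # None - skipped instead of raising
print(maybe(None)(5))       # None - no-op when the function itself is absent
```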
@@ -209,6 +230,14 @@ def mean_log_var_to_distr(
     std = (0.5 * log_var).exp()
     return Normal(mean, std)
 
+def safe_stack(tensors, dim = 0):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+
+    return stack(tensors, dim = dim)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
 
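`safe_stack` mirrors the existing `safe_cat`: it filters out `None` entries and returns `None` rather than erroring when nothing remains, so per-layer intermediates can be stacked without a prior emptiness check. A behavior sketch:

```python
import torch
from torch import stack

def exists(v):
    return v is not None

def safe_stack(tensors, dim = 0):
    tensors = [*filter(exists, tensors)]
    if len(tensors) == 0:
        return None
    return stack(tensors, dim = dim)

a, b = torch.randn(3), torch.randn(3)

assert safe_stack([a, None, b]).shape == (2, 3)  # Nones are dropped
assert safe_stack([None, None]) is None          # empty -> None, not an error
```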
@@ -1262,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1301,6 +1331,14 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
 
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
 
@@ -1313,8 +1351,9 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-
+        return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1327,6 +1366,17 @@ class Attention(Module):
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)
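The mix module produces a per-token, per-head gate in [0, 1] (the `Rearrange` to `b h n 1` lets it broadcast over the head feature dimension), and `v.lerp(residual_values, mix)` interpolates this layer's values toward values projected from the first layer's input. A shape-level sketch with hypothetical sizes:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq, dim, heads, dim_head = 2, 16, 512, 8, 64

# learned gate in [0, 1] per token and per head, broadcast over dim_head
to_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

tokens = torch.randn(batch, seq, dim)
v = torch.randn(batch, heads, seq, dim_head)                # this layer's values
residual_values = torch.randn(batch, heads, seq, dim_head)  # projected from the first layer's input

mix = to_mix(tokens)              # (b h n 1)
v = v.lerp(residual_values, mix)  # v * (1 - mix) + residual_values * mix
assert v.shape == (batch, heads, seq, dim_head)
```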
@@ -1367,10 +1417,10 @@ class Attention(Module):
 
         out = inverse_packed_batch(out)
 
-        if not
+        if not return_intermediates:
             return out
 
-        return out, stack((k, v))
+        return out, AttentionIntermediates(stack((k, v)), tokens)
 
 # feedforward
 
@@ -1410,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1418,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1439,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):
 
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer
 
         layers = []
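The residual values are computed once, up front, from the raw input tokens (normed, projected, split into heads) and then shared by every attention layer. A sketch of the projection's shapes, with hypothetical sizes:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, attn_heads, attn_dim_head = 512, 8, 64
dim_inner = attn_dim_head * attn_heads

to_value_residual = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, dim_inner, bias = False),
    Rearrange('... (h d) -> ... h d', h = attn_heads)
)

tokens = torch.randn(2, 4, 16, dim)          # (b t s d)
residual_values = to_value_residual(tokens)  # (b t s h d) - shared by all layers
assert residual_values.shape == (2, 4, 16, attn_heads, attn_dim_head)
```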
@@ -1450,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s
-            rearrange_from_attend = Rearrange('b s t
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
@@ -1484,7 +1549,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         tokens, # (b t s d)
         kv_cache: Tensor | None = None, # (y 2 b h t d)
-
+        return_intermediates = False
 
     ): # (b t s d) | (y 2 b h t d)
 
@@ -1507,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
@@ -1525,6 +1589,18 @@ class AxialSpaceTimeTransformer(Module):
 
             rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
+        # normed attention inputs
+
+        normed_time_attn_inputs = []
+        normed_space_attn_inputs = []
+
         # attention
 
         tokens = self.expand_streams(tokens)
@@ -1543,14 +1619,19 @@ class AxialSpaceTimeTransformer(Module):
 
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
 
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
-            tokens,
+            tokens, attn_intermediates = attn(
                 tokens,
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
-
+                residual_values = layer_residual_values,
+                return_intermediates = True
             )
 
             tokens = post_attn_rearrange(tokens)
@@ -1562,7 +1643,13 @@ class AxialSpaceTimeTransformer(Module):
             # save kv cache if is time layer
 
             if layer_is_time:
-                time_attn_kv_caches.append(next_kv_cache)
+                time_attn_kv_caches.append(attn_intermediates.next_kv_cache)
+
+            # save time attention inputs for decorr
+
+            space_or_time_inputs = normed_time_attn_inputs if layer_is_time else normed_space_attn_inputs
+
+            space_or_time_inputs.append(attn_intermediates.normed_inputs)
 
         tokens = self.reduce_streams(tokens)
 
@@ -1572,10 +1659,16 @@ class AxialSpaceTimeTransformer(Module):
             # just concat the past tokens back on for now, todo - clean up the logic
             out = cat((past_tokens, out), dim = 1)
 
-        if not
+        if not return_intermediates:
             return out
 
-
+        intermediates = TransformerIntermediates(
+            stack(time_attn_kv_caches),
+            safe_stack(normed_time_attn_inputs),
+            safe_stack(normed_space_attn_inputs)
+        )
+
+        return out, intermediates
 
 # video tokenizer
 
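The transformer now exposes a single `return_intermediates` flag in place of the old kv-cache-specific return, handing back a namedtuple that callers destructure by attribute or by position. An illustrative sketch of the contract (the `fake_transformer` stand-in is hypothetical):

```python
from collections import namedtuple
import torch

TransformerIntermediates = namedtuple(
    'TransformerIntermediates',
    ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs')
)

# stand-in for a transformer forward called with return_intermediates = True
def fake_transformer(tokens):
    return tokens, TransformerIntermediates(torch.randn(1, 2), None, torch.randn(1, 2))

out, intermediates = fake_transformer(torch.randn(2, 3))

# attribute access for clarity...
cache = intermediates.next_kv_cache  # normed_time_inputs may be None if no time blocks ran

# ...or positional destructuring, as VideoTokenizer.forward does below
out, (_, time_inputs, space_inputs) = fake_transformer(torch.randn(2, 3))
```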
@@ -1601,12 +1694,15 @@ class VideoTokenizer(Module):
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
         lpips_loss_network: Module | None = None,
         lpips_loss_weight = 0.2,
+        encoder_add_decor_aux_loss = False,
+        decor_auxx_loss_weight = 0.1,
+        decorr_sample_frac = 0.25,
         nd_rotary_kwargs: dict = dict(
             rope_min_freq = 1.,
             rope_max_freq = 10000.,
             rope_p_zero_freqs = 0.
         ),
-        num_residual_streams = 1
+        num_residual_streams = 1,
     ):
         super().__init__()
 
@@ -1701,6 +1797,14 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)
 
+        # decorr aux loss
+        # https://arxiv.org/abs/2510.14657
+
+        self.encoder_add_decor_aux_loss = encoder_add_decor_aux_loss
+        self.decorr_aux_loss_weight = decor_auxx_loss_weight
+
+        self.decorr_loss = DecorrelationLoss(decorr_sample_frac, soft_validate_num_sampled = True) if encoder_add_decor_aux_loss else None
+
     @property
     def device(self):
         return self.zero.device
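`DecorrelationLoss` is imported from vit-pytorch (hence the new `vit-pytorch>=1.15.3` dependency) and follows the paper linked above; broadly, it penalizes correlated features in the normed attention inputs, computed on a sampled fraction controlled by `decorr_sample_frac`. The sketch below is an illustrative from-scratch version of that idea, not the vit-pytorch implementation:

```python
import torch

def decorrelation_loss(feats, sample_frac = 0.25):
    # feats: (..., d) - flatten everything but the feature dim
    feats = feats.reshape(-1, feats.shape[-1])

    # subsample rows, mirroring the decorr_sample_frac knob
    num_sampled = max(2, int(feats.shape[0] * sample_frac))
    idx = torch.randperm(feats.shape[0])[:num_sampled]
    feats = feats[idx]

    # empirical feature covariance
    feats = feats - feats.mean(dim = 0)
    cov = (feats.T @ feats) / (feats.shape[0] - 1)

    # penalize only off-diagonal covariance - features should stay decorrelated
    off_diag = cov - torch.diag_embed(torch.diagonal(cov))
    return off_diag.pow(2).mean()

print(decorrelation_loss(torch.randn(8, 16, 512)))
```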
@@ -1814,7 +1918,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens = self.encoder_transformer(tokens)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -1836,17 +1940,28 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             lpips_loss = self.lpips(video, recon_video)
 
+        time_decorr_loss = space_decorr_loss = self.zero
+
+        if self.encoder_add_decor_aux_loss:
+            if exists(time_attn_normed_inputs):
+                time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+            if exists(space_attn_normed_inputs):
+                space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+
         # losses
 
         total_loss = (
             recon_loss +
-            lpips_loss * self.lpips_loss_weight
+            lpips_loss * self.lpips_loss_weight +
+            time_decorr_loss * self.decorr_aux_loss_weight +
+            space_decorr_loss * self.decorr_aux_loss_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        losses = (recon_loss, lpips_loss)
+        losses = (recon_loss, lpips_loss, decorr_loss)
 
         return total_loss, TokenizerLosses(*losses)
 
@@ -1870,10 +1985,9 @@ class DynamicsWorldModel(Module):
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
-        attn_kwargs: dict = dict(
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
         transformer_kwargs: dict = dict(),
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
@@ -2086,6 +2200,7 @@ class DynamicsWorldModel(Module):
         self.transformer = AxialSpaceTimeTransformer(
             dim = dim,
             depth = depth,
+            attn_heads = attn_heads,
             attn_dim_head = attn_dim_head,
             attn_softclamp_value = attn_softclamp_value,
             attn_kwargs = attn_kwargs,
@@ -2435,6 +2550,8 @@ class DynamicsWorldModel(Module):
     ):
         assert isinstance(experience, Experience)
 
+        experience = experience.to(self.device)
+
         latents = experience.latents
         actions = experience.actions
         old_log_probs = experience.log_probs
@@ -3325,7 +3442,7 @@ class DynamicsWorldModel(Module):
 
         # attention
 
-        tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache,
+        tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
 
         # unpack
 
{dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.4
+Version: 0.1.15
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -44,6 +44,7 @@ Requires-Dist: hl-gauss-pytorch
 Requires-Dist: hyper-connections>=0.2.1
 Requires-Dist: torch>=2.4
 Requires-Dist: torchvision
+Requires-Dist: vit-pytorch>=1.15.3
 Requires-Dist: x-mlps-pytorch>=0.0.29
 Provides-Extra: examples
 Provides-Extra: test
@@ -57,7 +58,7 @@ Description-Content-Type: text/markdown
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[Discord channel](https://discord.gg/
+[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
 
 ## Appreciation
 
@@ -90,7 +91,7 @@ video = torch.randn(2, 3, 10, 256, 256)
 # learn the tokenizer
 
 loss = tokenizer(video)
-loss.backward()
+loss.backward()
 
 # dynamics world model
 
dreamer4-0.1.15.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=BVMAIfhqv7wO0FWo-SBfUnyXEQcMljh6CyaHeZ8GmCI,125018
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+dreamer4-0.1.15.dist-info/METADATA,sha256=ghChOd76397jZ_XwFwKRv1lxP1ZFqNgQfSKBUB7DXoo,4973
+dreamer4-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.15.dist-info/RECORD,,
dreamer4-0.1.4.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
-dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.1.4.dist-info/RECORD,,
{dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/WHEEL
File without changes

{dreamer4-0.1.4.dist-info → dreamer4-0.1.15.dist-info}/licenses/LICENSE
File without changes