dreamer4 0.1.10__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- dreamer4/dreamer4.py +78 -13
- {dreamer4-0.1.10.dist-info → dreamer4-0.1.15.dist-info}/METADATA +1 -1
- dreamer4-0.1.15.dist-info/RECORD +8 -0
- dreamer4-0.1.10.dist-info/RECORD +0 -8
- {dreamer4-0.1.10.dist-info → dreamer4-0.1.15.dist-info}/WHEEL +0 -0
- {dreamer4-0.1.10.dist-info → dreamer4-0.1.15.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -189,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()
 
+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers
 
 def is_empty(t):

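The new maybe helper wraps a function so that None passes straight through instead of raising, and a missing function likewise yields None; the transformer forward pass below relies on it as maybe(pre_attn_rearrange)(residual_values). A quick usage sketch, reusing maybe as defined in the hunk above and the module's existing exists helper (exists(v) is v is not None):

double = maybe(lambda t: t * 2)

assert double(3) == 6          # a present value goes through the wrapped function
assert double(None) is None    # None short-circuits instead of raising
assert maybe(None)(3) is None  # a missing function also yields None
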
@@ -223,6 +230,14 @@ def mean_log_var_to_distr(
     std = (0.5 * log_var).exp()
     return Normal(mean, std)
 
+def safe_stack(tensors, dim = 0):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+
+    return stack(tensors, dim = dim)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
 

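safe_stack mirrors the existing safe_cat: drop None entries, return None when nothing is left, otherwise stack along dim. This is what later lets TransformerIntermediates stack normed-input lists that may have stayed empty. A usage sketch, reusing safe_stack as defined above and assuming stack is the module's torch.stack import:

import torch

a = torch.randn(2, 3)

print(safe_stack([a, None, a]).shape)  # None entries filtered -> torch.Size([2, 2, 3])
print(safe_stack([None, None]))        # nothing left to stack -> None
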
@@ -1276,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()

@@ -1315,6 +1331,14 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
 
+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper
 

@@ -1329,6 +1353,7 @@
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')

@@ -1341,6 +1366,17 @@
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm
 
         q = self.q_heads_rmsnorm(q)

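Taken together, these Attention changes implement value residual learning (https://arxiv.org/abs/2410.17897): every layer receives values projected once from the transformer's initial tokens (the residual values) and lerps its own values toward them through a learned per-head, per-token sigmoid gate. A self-contained sketch of just the mixing step; the shapes and the to_mix name are illustrative stand-ins for the module's to_learned_value_residual_mix:

import torch
from torch import nn
from einops.layers.torch import Rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512   # hypothetical batch, heads, seq len, head dim, model dim

to_mix = nn.Sequential(               # learned gate in [0, 1] per head and position
    nn.Linear(dim, h),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

tokens = torch.randn(b, n, dim)
v = torch.randn(b, h, n, d)           # this layer's values
residual_v = torch.randn(b, h, n, d)  # values projected from the initial tokens

mix = to_mix(tokens)                  # (b, h, n, 1)
v = v.lerp(residual_v, mix)           # v * (1 - mix) + residual_v * mix
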
@@ -1424,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,

@@ -1432,7 +1469,8 @@
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'

@@ -1453,6 +1491,19 @@
 
         self.time_rotary = Rotary1D(attn_dim_head)
 
+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer
 
         layers = []

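On the transformer side, value_residual costs a single extra projection, computed once per forward pass from the initial tokens and shared by every attention layer. A shape sketch under the defaults above; the model dim and token shape are hypothetical, and nn.RMSNorm assumes torch >= 2.4:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, attn_heads, attn_dim_head = 512, 8, 64
dim_inner = attn_dim_head * attn_heads

to_value_residual = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, dim_inner, bias = False),
    Rearrange('... (h d) -> ... h d', h = attn_heads)
)

tokens = torch.randn(2, 4, 9, dim)      # (batch, time, space, dim)
print(to_value_residual(tokens).shape)  # torch.Size([2, 4, 9, 8, 64]), i.e. (b, t, s, h, d)
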
@@ -1464,13 +1515,13 @@
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)
 
-            rearrange_to_attend = Rearrange('b t s
-            rearrange_from_attend = Rearrange('b s t
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()
 
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))

@@ -1521,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 

@@ -1539,6 +1589,13 @@
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
 
+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs
 
         normed_time_attn_inputs = []

@@ -1562,6 +1619,10 @@
 
             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
 
+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer
 
             tokens, attn_intermediates = attn(

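This line is where maybe earns its keep: time blocks attend over a transposed (b s t ...) layout, so the shared residual values must be rearranged exactly like the tokens before entering each layer, and the maybe wrapper keeps this a no-op when value_residual is disabled and residual_values is None. The '...' in the new rearrange patterns is presumably what lets one Rearrange serve both the (b t s d) tokens and the (b t s h d) residual values. A sketch with hypothetical shapes, reusing maybe from the first hunk:

import torch
from einops.layers.torch import Rearrange

pre_attn_rearrange = Rearrange('b t s ... -> b s t ...')  # a time block's rearrange

residual_values = torch.randn(2, 4, 9, 8, 64)             # (b, t, s, h, d)
print(maybe(pre_attn_rearrange)(residual_values).shape)   # torch.Size([2, 9, 4, 8, 64])
print(maybe(pre_attn_rearrange)(None))                    # None, so the attention skips the mix
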
@@ -1569,6 +1630,7 @@
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )
 

@@ -1602,8 +1664,8 @@ class AxialSpaceTimeTransformer(Module):
 
         intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
-
-
+            safe_stack(normed_time_attn_inputs),
+            safe_stack(normed_space_attn_inputs)
         )
 
         return out, intermediates

@@ -1881,8 +1943,11 @@ class VideoTokenizer(Module):
         time_decorr_loss = space_decorr_loss = self.zero
 
         if self.encoder_add_decor_aux_loss:
-
-
+            if exists(time_attn_normed_inputs):
+                time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+            if exists(space_attn_normed_inputs):
+                space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
 
         # losses
 

@@ -1920,10 +1985,9 @@ class DynamicsWorldModel(Module):
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
-        attn_kwargs: dict = dict(
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
         transformer_kwargs: dict = dict(),
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),

@@ -2136,6 +2200,7 @@
         self.transformer = AxialSpaceTimeTransformer(
             dim = dim,
             depth = depth,
+            attn_heads = attn_heads,
             attn_dim_head = attn_dim_head,
             attn_softclamp_value = attn_softclamp_value,
             attn_kwargs = attn_kwargs,

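For callers, the net effect of these last two hunks is that the head count moves from an attn_kwargs entry to a first-class attn_heads argument, which DynamicsWorldModel forwards to AxialSpaceTimeTransformer (where it is needed to size to_value_residual). Schematically, with the constructor calls left as comments since the remaining required arguments are not shown in this diff:

# before (0.1.10): head count buried in attn_kwargs
# model = DynamicsWorldModel(..., attn_kwargs = dict(heads = 8))

# after (0.1.15): head count as a top-level argument, alongside attn_dim_head
# model = DynamicsWorldModel(..., attn_heads = 8, attn_dim_head = 64)
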
dreamer4-0.1.15.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=BVMAIfhqv7wO0FWo-SBfUnyXEQcMljh6CyaHeZ8GmCI,125018
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+dreamer4-0.1.15.dist-info/METADATA,sha256=ghChOd76397jZ_XwFwKRv1lxP1ZFqNgQfSKBUB7DXoo,4973
+dreamer4-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.15.dist-info/RECORD,,

dreamer4-0.1.10.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
-dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
-dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.1.10.dist-info/RECORD,,

{dreamer4-0.1.10.dist-info → dreamer4-0.1.15.dist-info}/WHEEL
File without changes

{dreamer4-0.1.10.dist-info → dreamer4-0.1.15.dist-info}/licenses/LICENSE
File without changes