dreamer4 0.1.10__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -189,6 +189,13 @@ def sample_prob(prob):
 def is_power_two(num):
     return log2(num).is_integer()

+def maybe(fn):
+    def inner(t, *args, **kwargs):
+        if not exists(t) or not exists(fn):
+            return None
+        return fn(t)
+    return inner
+
 # tensor helpers

 def is_empty(t):
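
A note on the new `maybe` helper: it wraps a function so that a missing (`None`) input, or a missing function, short-circuits to `None` instead of raising; also note that in this version `inner` accepts `*args, **kwargs` but only forwards `t` to `fn`. A minimal sketch of the behaviour, assuming the library's usual `exists(v)` meaning `v is not None`:

def exists(v):
    return v is not None

def maybe(fn):
    def inner(t, *args, **kwargs):
        if not exists(t) or not exists(fn):
            return None
        return fn(t)
    return inner

double = maybe(lambda t: t * 2)
print(double(3))     # 6
print(double(None))  # None, no error
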
@@ -223,6 +230,14 @@ def mean_log_var_to_distr(
     std = (0.5 * log_var).exp()
     return Normal(mean, std)

+def safe_stack(tensors, dim = 0):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+
+    return stack(tensors, dim = dim)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]

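
`safe_stack` mirrors the existing `safe_cat`: `None` entries are filtered out before stacking, and an all-`None` or empty list returns `None` instead of letting `torch.stack` raise on an empty sequence. A small usage sketch with made-up shapes, assuming the helper is importable from `dreamer4.dreamer4` where the diff above defines it at module level:

import torch
from dreamer4.dreamer4 import safe_stack  # module-level helper added in this release

a = torch.randn(2, 4)
b = torch.randn(2, 4)

print(safe_stack([a, None, b]).shape)  # torch.Size([2, 2, 4])
print(safe_stack([None, None]))        # None
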
@@ -1276,7 +1291,8 @@ class Attention(Module):
         pre_rmsnorm = True,
         gate_values = True,
         rmsnorm_query = False, # a paper claims that it is better to just norm only the keys https://openreview.net/forum?id=HkztQWZfl2
-        rmsnorm_key = True
+        rmsnorm_key = True,
+        value_residual = True
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -1315,6 +1331,14 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) if rmsnorm_query else nn.Identity()
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()

+        # value residual
+
+        self.to_learned_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if value_residual else None
+
     def muon_parameters(self):
         # omit the queries and keys for now given what we learned from kimi 2 paper

@@ -1329,6 +1353,7 @@ class Attention(Module):
         kv_cache = None,
         return_intermediates = False,
         rotary_pos_emb = None,
+        residual_values = None, # (b n h d)
         attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -1341,6 +1366,17 @@ class Attention(Module):

         q, k, v = map(self.split_heads, (q, k, v))

+        # handle maybe value residual
+
+        if exists(residual_values):
+            residual_values = rearrange(residual_values, '... n h d -> (...) h n d')
+
+            assert exists(self.to_learned_value_residual_mix)
+
+            learned_mix = self.to_learned_value_residual_mix(tokens)
+
+            v = v.lerp(residual_values, learned_mix)
+
         # qk rmsnorm

         q = self.q_heads_rmsnorm(q)
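
The value residual follows https://arxiv.org/abs/2410.17897 (referenced in the transformer's new `value_residual` kwarg below): each layer's value heads are interpolated back toward values projected from the initial tokens, with a learned per-head, per-token mix given by a sigmoid-gated linear projection. A minimal standalone sketch of just the mixing step, with hypothetical shapes (`b` batch, `h` heads, `n` tokens, `d` head dim) rather than the library's exact tensor layout:

import torch
from torch import nn
from einops.layers.torch import Rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512

tokens = torch.randn(b, n, dim)               # layer input tokens
v = torch.randn(b, h, n, d)                   # this layer's value heads
residual_values = torch.randn(b, h, n, d)     # values projected once from the initial tokens

# learned mix in [0, 1], one scalar per (batch, head, token)
to_mix = nn.Sequential(
    nn.Linear(dim, h),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

mix = to_mix(tokens)                          # (b, h, n, 1)
v = v.lerp(residual_values, mix)              # v * (1 - mix) + residual_values * mix
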
@@ -1424,6 +1460,7 @@ class AxialSpaceTimeTransformer(Module):
         self,
         dim,
         depth,
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         time_block_every = 4,
@@ -1432,7 +1469,8 @@ class AxialSpaceTimeTransformer(Module):
         num_residual_streams = 1,
         num_special_spatial_tokens = 1,
         special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to itself while spatial modalities attend to the latents and everything)
-        final_norm = True
+        final_norm = True,
+        value_residual = True # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
     ):
         super().__init__()
         assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1453,6 +1491,19 @@ class AxialSpaceTimeTransformer(Module):

         self.time_rotary = Rotary1D(attn_dim_head)

+        # project initial for value residuals
+
+        self.value_residual = value_residual
+
+        if value_residual:
+            dim_inner = attn_dim_head * attn_heads
+
+            self.to_value_residual = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_inner, bias = False),
+                Rearrange('... (h d) -> ... h d', h = attn_heads)
+            )
+
         # transformer

         layers = []
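
At the transformer level the residual values are not cached from the first attention layer; instead the input tokens are normed and projected once, up front, into per-head values that every layer then mixes in. A rough shape sketch under hypothetical sizes (note `nn.RMSNorm` needs a recent torch, >= 2.4):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, attn_heads, attn_dim_head = 512, 8, 64
b, t, s = 2, 4, 16                            # batch, time steps, spatial tokens

to_value_residual = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, attn_dim_head * attn_heads, bias = False),
    Rearrange('... (h d) -> ... h d', h = attn_heads)
)

tokens = torch.randn(b, t, s, dim)
residual_values = to_value_residual(tokens)   # (b, t, s, h, d) -> (2, 4, 16, 8, 64)
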
@@ -1464,13 +1515,13 @@ class AxialSpaceTimeTransformer(Module):
             is_time_block = divisible_by(layer_index, time_block_every)
             is_time.append(is_time_block)

-            rearrange_to_attend = Rearrange('b t s d -> b s t d') if is_time_block else Identity()
-            rearrange_from_attend = Rearrange('b s t d -> b t s d') if is_time_block else Identity()
+            rearrange_to_attend = Rearrange('b t s ... -> b s t ...') if is_time_block else Identity()
+            rearrange_from_attend = Rearrange('b s t ... -> b t s ...') if is_time_block else Identity()

             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                hyper_conn(branch = Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs)),
+                hyper_conn(branch = Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, value_residual = value_residual, **attn_kwargs)),
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))

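
The rearrange patterns trade the explicit trailing `d` for `...` so the same time/space transpose applies both to the tokens `(b t s d)` and to the per-head residual values `(b t s h d)`. A tiny check of the pattern, shapes hypothetical:

import torch
from einops.layers.torch import Rearrange

to_time_attend = Rearrange('b t s ... -> b s t ...')

tokens = torch.randn(2, 4, 16, 512)           # (b, t, s, d)
values = torch.randn(2, 4, 16, 8, 64)         # (b, t, s, h, d)

print(to_time_attend(tokens).shape)           # torch.Size([2, 16, 4, 512])
print(to_time_attend(values).shape)           # torch.Size([2, 16, 4, 8, 64])
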
@@ -1521,7 +1572,6 @@ class AxialSpaceTimeTransformer(Module):

         time_attn_kv_caches = []

-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

@@ -1539,6 +1589,13 @@ class AxialSpaceTimeTransformer(Module):

         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)

+        # value residual
+
+        residual_values = None
+
+        if self.value_residual:
+            residual_values = self.to_value_residual(tokens)
+
         # normed attention inputs

         normed_time_attn_inputs = []
@@ -1562,6 +1619,10 @@ class AxialSpaceTimeTransformer(Module):

             maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None

+            # residual values
+
+            layer_residual_values = maybe(pre_attn_rearrange)(residual_values)
+
             # attention layer

             tokens, attn_intermediates = attn(
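
Inside the layer loop the shared residual values get the same axis transpose as the tokens, via `maybe(pre_attn_rearrange)`, so the `value_residual = False` path (where `residual_values` stays `None`) needs no special casing. A sketch using the `maybe` helper from the top of this diff and a stand-in rearrange (names and shapes hypothetical):

import torch
from einops.layers.torch import Rearrange

pre_attn_rearrange = Rearrange('b t s ... -> b s t ...')   # one time block's pre-attention transpose
residual_values = torch.randn(2, 4, 16, 8, 64)

print(maybe(pre_attn_rearrange)(residual_values).shape)    # torch.Size([2, 16, 4, 8, 64])
print(maybe(pre_attn_rearrange)(None))                     # None, when value_residual is off
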
@@ -1569,6 +1630,7 @@ class AxialSpaceTimeTransformer(Module):
                 rotary_pos_emb = layer_rotary_pos_emb,
                 attend_fn = attend_fn,
                 kv_cache = maybe_kv_cache,
+                residual_values = layer_residual_values,
                 return_intermediates = True
             )

@@ -1602,8 +1664,8 @@ class AxialSpaceTimeTransformer(Module):

         intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
-            stack(normed_time_attn_inputs),
-            stack(normed_space_attn_inputs)
+            safe_stack(normed_time_attn_inputs),
+            safe_stack(normed_space_attn_inputs)
         )

         return out, intermediates
@@ -1881,8 +1943,11 @@ class VideoTokenizer(Module):
         time_decorr_loss = space_decorr_loss = self.zero

         if self.encoder_add_decor_aux_loss:
-            time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
-            space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)
+            if exists(time_attn_normed_inputs):
+                time_decorr_loss = self.decorr_loss(time_attn_normed_inputs)
+
+            if exists(space_attn_normed_inputs):
+                space_decorr_loss = self.decorr_loss(space_attn_normed_inputs)

         # losses

@@ -1920,10 +1985,9 @@ class DynamicsWorldModel(Module):
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
-        attn_kwargs: dict = dict(
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
         transformer_kwargs: dict = dict(),
+        attn_heads = 8,
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
@@ -2136,6 +2200,7 @@ class DynamicsWorldModel(Module):
         self.transformer = AxialSpaceTimeTransformer(
             dim = dim,
             depth = depth,
+            attn_heads = attn_heads,
             attn_dim_head = attn_dim_head,
             attn_softclamp_value = attn_softclamp_value,
             attn_kwargs = attn_kwargs,
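
One knock-on effect of the shared value-residual projection: the transformer itself now needs the head count, so `heads` moves out of `attn_kwargs` and becomes a first-class `attn_heads` argument that `DynamicsWorldModel` forwards to `AxialSpaceTimeTransformer` and on to each `Attention`. Roughly, for callers (other required constructor arguments elided, values hypothetical):

# 0.1.10
# model = DynamicsWorldModel(..., attn_kwargs = dict(heads = 8))

# 0.1.15
# model = DynamicsWorldModel(..., attn_heads = 8)
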

dreamer4-0.1.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.10
+Version: 0.1.15
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4

dreamer4-0.1.15.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=BVMAIfhqv7wO0FWo-SBfUnyXEQcMljh6CyaHeZ8GmCI,125018
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
+dreamer4-0.1.15.dist-info/METADATA,sha256=ghChOd76397jZ_XwFwKRv1lxP1ZFqNgQfSKBUB7DXoo,4973
+dreamer4-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.15.dist-info/RECORD,,

dreamer4-0.1.10.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=_xr_XJGfqhCabRV0vnue4zypHZ4kXeUDZp1N6RF2AoY,122988
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=h_BMi-P2QMVi-IWQCkejPmyA0UzHgKtE1n7Qn1-IrxE,15093
-dreamer4-0.1.10.dist-info/METADATA,sha256=oTK9b_fWDCQTC89Y30OBY_2BzJJ6ih25BzgO0D-SApg,4973
-dreamer4-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.1.10.dist-info/RECORD,,