dreamer4 0.0.90.tar.gz → 0.0.91.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.90
+Version: 0.0.91
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -186,6 +186,15 @@ def lens_to_mask(t, max_len = None):
 
     return einx.less('j, i -> i j', seq, t)
 
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
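For orientation, here is a minimal standalone sketch of how the new masked_mean helper behaves; the exists helper and the example tensors are illustrative assumptions, not part of the diff. The all-False branch falls through to a sum over an empty selection, i.e. a differentiable zero rather than the NaN a plain mean would produce.

import torch

def exists(v):
    # assumed definition; the repo's own `exists` is not shown in this diff
    return v is not None

def masked_mean(t, mask = None):
    if not exists(mask):
        return t.mean()

    if not mask.any():
        return t[mask].sum()  # empty selection -> differentiable zero instead of NaN

    return t[mask].mean()

t = torch.tensor([1., 2., 3., 4.])

masked_mean(t)                                              # tensor(2.5000)
masked_mean(t, torch.tensor([True, False, True, False]))    # tensor(2.)
masked_mean(t, torch.tensor([False, False, False, False]))  # tensor(0.)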
@@ -1824,6 +1833,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2037,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related
 
         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2348,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-        use_signed_advantage = True,
+        use_pmpo = True,
         eps = 1e-6
     ):
 
@@ -2374,6 +2388,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)
 
+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
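For context, lens_to_mask (partially visible in an earlier hunk) turns per-sequence lengths into a boolean (batch, time) mask; subtracting is_truncated from the lengths first drops the final bootstrapped step from the learnable positions, per the in-code comment. The einx-free equivalent below is an assumption inferred from the einx.less expression, and the reuse of the helper name is for illustration only.

import torch

def lens_to_mask(lens, max_len):
    # assumed equivalent of einx.less('j, i -> i j', seq, lens):
    # position j of sequence i is valid when j < lens[i]
    seq = torch.arange(max_len, device = lens.device)
    return seq[None, :] < lens[:, None]

lens = torch.tensor([3, 1])

lens_to_mask(lens, 4)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])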
@@ -2417,8 +2433,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        if use_signed_advantage:
-            advantage = advantage.sign()
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
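The sign split above feeds the PMPO-style objective assembled in a later hunk. A hedged reconstruction of eq. (10) of arXiv:2410.04166 as it appears to be used here, with α standing for pmpo_pos_to_neg_weight and the expectations taken over the (masked) batch and time positions, is:

\mathcal{L}_{\text{PMPO}}(\theta) = -\Bigl( \alpha \, \mathbb{E}_{\hat{A}_t \ge 0}\bigl[\log \pi_\theta(a_t \mid s_t)\bigr] - (1 - \alpha) \, \mathbb{E}_{\hat{A}_t < 0}\bigl[\log \pi_\theta(a_t \mid s_t)\bigr] \Bigr)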
 
@@ -2464,35 +2481,48 @@ class DynamicsWorldModel(Module):
2464
2481
  log_probs = safe_cat(log_probs, dim = -1)
2465
2482
  entropies = safe_cat(entropies, dim = -1)
2466
2483
 
2467
- ratio = (log_probs - old_log_probs).exp()
2484
+ advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
2468
2485
 
2469
- clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
2486
+ if use_pmpo:
2487
+ # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
2488
+ # seems to be weighted across batch and time, iiuc
2489
+ # eq (10) in https://arxiv.org/html/2410.04166v1
2470
2490
 
2471
- advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
2491
+ if exists(mask):
2492
+ pos_advantage_mask &= mask
2493
+ neg_advantage_mask &= mask
2494
+
2495
+ α = self.pmpo_pos_to_neg_weight
2472
2496
 
2473
- # clipped surrogate loss
2497
+ pos = masked_mean(log_probs, pos_advantage_mask)
2498
+ neg = -masked_mean(log_probs, neg_advantage_mask)
2474
2499
 
2475
- policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
2476
- policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
2500
+ policy_loss = -(α * pos + (1. - α) * neg)
2477
2501
 
2478
- policy_loss = policy_loss.mean()
2502
+ else:
2503
+ # ppo clipped surrogate loss
2504
+
2505
+ ratio = (log_probs - old_log_probs).exp()
2506
+ clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
2507
+
2508
+ policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
2509
+ policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
2510
+
2511
+ policy_loss = masked_mean(policy_loss, mask)
2479
2512
 
2480
2513
  # handle entropy loss for naive exploration bonus
2481
2514
 
2482
2515
  entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
2483
2516
 
2517
+ entropy_loss = masked_mean(entropy_loss, mask)
2518
+
2519
+ # total policy loss
2520
+
2484
2521
  total_policy_loss = (
2485
2522
  policy_loss +
2486
2523
  entropy_loss * self.policy_entropy_weight
2487
2524
  )
2488
2525
 
2489
- # maybe handle variable lengths
2490
-
2491
- if is_var_len:
2492
- total_policy_loss = total_policy_loss[mask].mean()
2493
- else:
2494
- total_policy_loss = total_policy_loss.mean()
2495
-
2496
2526
  # maybe take policy optimizer step
2497
2527
 
2498
2528
  if exists(policy_optim):
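A minimal, self-contained sketch of the PMPO branch as it reads in this hunk, assuming (batch, time) advantages and (batch, time, num_actions) per-action log-probs; the shapes, toy tensors, and the inlined masked_mean are illustrative assumptions rather than the package's actual call path.

import torch

def masked_mean(t, mask = None):
    if mask is None:
        return t.mean()
    if not mask.any():
        return t[mask].sum()  # empty selection -> differentiable zero
    return t[mask].mean()

b, t, na = 2, 4, 3                                    # batch, time, number of actions (toy sizes)

log_probs = torch.randn(b, t, na, requires_grad = True)
advantage = torch.randn(b, t)
mask      = torch.ones(b, t, dtype = torch.bool)      # variable-length mask; all positions valid here
alpha     = 0.5                                       # plays the role of pmpo_pos_to_neg_weight

pos_advantage_mask = (advantage >= 0.) & mask         # only the sign of the advantage is used
neg_advantage_mask = (advantage < 0.) & mask

# a (b, t) boolean mask indexes whole rows of the (b, t, na) log-probs
pos = masked_mean(log_probs, pos_advantage_mask)
neg = -masked_mean(log_probs, neg_advantage_mask)

policy_loss = -(alpha * pos + (1. - alpha) * neg)     # -(α·pos + (1-α)·neg), as in the diff
policy_loss.backward()

With alpha = 0.5 the positive and negative sides contribute equally, matching the default pmpo_pos_to_neg_weight = 0.5 added in this release.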
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.90"
+version = "0.0.91"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -611,13 +611,13 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
 
 @param('vectorized', (False, True))
-@param('use_signed_advantage', (False, True))
+@param('use_pmpo', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
 @param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage,
+    use_pmpo,
     env_can_terminate,
     env_can_truncate,
     store_agent_embed
@@ -674,7 +674,7 @@ def test_online_rl(
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
 
     actor_loss.backward()
     critic_loss.backward()