dreamer4 0.0.90__tar.gz → 0.0.92__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.90
+Version: 0.0.92
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -186,9 +186,26 @@ def lens_to_mask(t, max_len = None):
 
     return einx.less('j, i -> i j', seq, t)
 
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]
 
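
For readers skimming the diff: a minimal standalone sketch of how the two new helpers behave. The shapes and values below are illustrative, not taken from the package, and the package's `exists(...)` check is swapped for a plain `None` check so the snippet runs on its own.

import torch
from torch.distributions import Normal

def masked_mean(t, mask = None):
    # mean over unmasked entries; plain mean when no mask is given
    if mask is None:
        return t.mean()
    if not mask.any():
        return t[mask].sum()   # empty selection -> zero tensor instead of NaN
    return t[mask].mean()

def mean_log_var_to_distr(mean_log_var):
    # trailing dim packs (mean, log_var); convert log-variance to a std dev
    mean, log_var = mean_log_var.unbind(dim = -1)
    return Normal(mean, (0.5 * log_var).exp())

values = torch.randn(2, 3)
mask = torch.tensor([[True, False, True], [False, False, False]])
print(masked_mean(values, mask))                  # mean over the two selected entries

mean_log_var = torch.randn(2, 3, 2)
distr = mean_log_var_to_distr(mean_log_var)
print(distr.log_prob(torch.zeros(2, 3)).shape)    # torch.Size([2, 3])

Returning `t[mask].sum()` when the mask selects nothing keeps the result a zero tensor that stays connected to the graph, where a mean over zero elements would give NaN.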
@@ -815,10 +832,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None
 
         if exists(continuous_targets):
-            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)
 
             if return_entropies:
@@ -833,6 +847,54 @@ class ActionEmbedder(Module):
 
         return log_probs, entropies
 
+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
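
The new `ActionEmbedder.kl_div` combines a categorical KL over the discrete action logits with a closed-form Gaussian KL over the continuous (mean, log-variance) parameters. A rough self-contained sketch of those two terms, with made-up shapes and without the per-action-group splitting the real method performs:

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl

# discrete term: KL(tgt || src) per position, summed over the action choices
src_logits = torch.randn(2, 3, 8)
tgt_logits = torch.randn(2, 3, 8)

discrete_kl = F.kl_div(
    src_logits.log_softmax(dim = -1),
    tgt_logits.softmax(dim = -1),
    reduction = 'none'
).sum(dim = -1)

print(discrete_kl.shape)        # torch.Size([2, 3])

# continuous term: closed-form KL between the Normals recovered from (mean, log_var)
def to_normal(mean_log_var):
    mean, log_var = mean_log_var.unbind(dim = -1)
    return Normal(mean, (0.5 * log_var).exp())

src_mean_log_var = torch.randn(2, 3, 2, 2)   # (..., num continuous actions, 2)
tgt_mean_log_var = torch.randn(2, 3, 2, 2)

continuous_kl = kl.kl_divergence(to_normal(src_mean_log_var), to_normal(tgt_mean_log_var))
print(continuous_kl.shape)      # torch.Size([2, 3, 2])

Note that `F.kl_div(input, target)` expects log-probabilities for `input` and probabilities for `target`, which is why the source logits go through `log_softmax` and the target logits through `softmax`.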
@@ -1824,6 +1886,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2090,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related
 
         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2401,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-        use_signed_advantage = True,
+        use_pmpo = True,
         eps = 1e-6
     ):
 
@@ -2374,6 +2441,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)
 
+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
@@ -2417,8 +2486,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        if use_signed_advantage:
-            advantage = advantage.sign()
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
@@ -2464,35 +2534,48 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)
 
-        ratio = (log_probs - old_log_probs).exp()
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
 
-        clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1
 
-        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask
 
-        # clipped surrogate loss
+            α = self.pmpo_pos_to_neg_weight
 
-        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)
 
-        policy_loss = policy_loss.mean()
+            policy_loss = -(α * pos + (1. - α) * neg)
+
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+            policy_loss = masked_mean(policy_loss, mask)
 
         # handle entropy loss for naive exploration bonus
 
         entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
+        entropy_loss = masked_mean(entropy_loss, mask)
+
+        # total policy loss
+
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maybe handle variable lengths
-
-        if is_var_len:
-            total_policy_loss = total_policy_loss[mask].mean()
-        else:
-            total_policy_loss = total_policy_loss.mean()
-
         # maybe take policy optimizer step
 
         if exists(policy_optim):
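
The policy objective now defaults to the pmpo formulation referenced in the comments above (eq (10) in https://arxiv.org/html/2410.04166v1): log-probabilities are averaged separately over positive- and negative-advantage positions and combined with the `pmpo_pos_to_neg_weight` factor, while the previous PPO clipped surrogate stays available behind `use_pmpo = False`. A rough standalone sketch of the pmpo branch; shapes and variable names here are illustrative:

import torch

def masked_mean(t, mask):
    # mirrors the helper added above: mean over the selected (batch, time) positions
    return t[mask].mean() if mask.any() else t[mask].sum()

log_probs = torch.randn(4, 16, 3, requires_grad = True)   # (batch, time, num actions)
advantage = torch.randn(4, 16)                             # per (batch, time) advantage
valid     = torch.ones(4, 16, dtype = torch.bool)          # variable-length mask

alpha = 0.5                                                # pmpo_pos_to_neg_weight
pos_mask = (advantage >= 0.) & valid
neg_mask = ~(advantage >= 0.) & valid

# push up log-probs where the advantage was positive, push them down where negative,
# ignoring the advantage magnitude entirely
pos = masked_mean(log_probs, pos_mask)
neg = -masked_mean(log_probs, neg_mask)

policy_loss = -(alpha * pos + (1. - alpha) * neg)
policy_loss.backward()

In the package the entropy bonus is reduced with `masked_mean` as well and added to this loss scaled by `policy_entropy_weight`; the sketch leaves that term out.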
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.90"
+version = "0.0.92"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -346,6 +346,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)
 
+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3, 2)
+    assert continuous_kl_div.shape == (2, 3, 2)
+
     # return discrete split by number of actions
 
     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -611,13 +620,13 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
 
 @param('vectorized', (False, True))
-@param('use_signed_advantage', (False, True))
+@param('use_pmpo', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
 @param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage,
+    use_pmpo,
     env_can_terminate,
     env_can_truncate,
     store_agent_embed
@@ -674,7 +683,7 @@ def test_online_rl(
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
 
     actor_loss.backward()
     critic_loss.backward()
5 files without changes