dreamer4 0.0.89__py3-none-any.whl → 0.0.91__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +49 -17
- {dreamer4-0.0.89.dist-info → dreamer4-0.0.91.dist-info}/METADATA +1 -1
- dreamer4-0.0.91.dist-info/RECORD +8 -0
- dreamer4-0.0.89.dist-info/RECORD +0 -8
- {dreamer4-0.0.89.dist-info → dreamer4-0.0.91.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.89.dist-info → dreamer4-0.0.91.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -186,6 +186,15 @@ def lens_to_mask(t, max_len = None):
 
     return einx.less('j, i -> i j', seq, t)
 
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
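The masked_mean helper added above backs the new PMPO branch further down: it averages only the positions a boolean mask selects, falls back to a plain mean when no mask is given, and returns a zero-valued (but still differentiable) empty sum when the mask selects nothing, so an empty positive- or negative-advantage set contributes nothing to the loss instead of producing a NaN. A minimal standalone sketch of the same logic, assuming only torch:

import torch

def exists(v):
    return v is not None

def masked_mean(t, mask = None):
    # no mask given - plain mean over all elements
    if not exists(mask):
        return t.mean()

    # mask selects nothing - the empty sum gives 0. where a mean would give nan
    if not mask.any():
        return t[mask].sum()

    # otherwise average only the selected elements
    return t[mask].mean()

t = torch.tensor([1., 2., 3., 4.])
print(masked_mean(t))                                            # tensor(2.5000)
print(masked_mean(t, torch.tensor([True, True, False, False])))  # tensor(1.5000)
print(masked_mean(t, torch.zeros(4, dtype = torch.bool)))        # tensor(0.)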
@@ -1824,6 +1833,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2037,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related
 
         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2348,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-
+        use_pmpo = True,
         eps = 1e-6
     ):
 
@@ -2374,6 +2388,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)
 
+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
@@ -2417,8 +2433,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        if
-
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
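The two branches above diverge on how the advantage enters the loss: the PMPO path keeps only its sign, as a pair of boolean masks, while the retained PPO path keeps magnitudes but normalizes them with a layer norm over the whole tensor. A small illustrative sketch with a hypothetical advantage tensor:

import torch
import torch.nn.functional as F

advantage = torch.tensor([[0.5, -1.2, 3.0, -0.1]])

# pmpo path - only the sign of the advantage survives
pos_advantage_mask = advantage >= 0.    # tensor([[ True, False,  True, False]])
neg_advantage_mask = ~pos_advantage_mask

# ppo path - magnitudes kept, normalized to zero mean / unit variance
advantage_normed = F.layer_norm(advantage, advantage.shape, eps = 1e-6)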
@@ -2464,35 +2481,48 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)
 
-
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
 
-
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1
 
-
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask
+
+            α = self.pmpo_pos_to_neg_weight
 
-
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)
 
-
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+            policy_loss = -(α * pos + (1. - α) * neg)
 
-
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+            policy_loss = masked_mean(policy_loss, mask)
 
         # handle entropy loss for naive exploration bonus
 
         entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
+        entropy_loss = masked_mean(entropy_loss, mask)
+
+        # total policy loss
+
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maybe handle variable lengths
-
-        if is_var_len:
-            total_policy_loss = total_policy_loss[mask].mean()
-        else:
-            total_policy_loss = total_policy_loss.mean()
-
         # maybe take policy optimizer step
 
         if exists(policy_optim):
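Putting the pieces together, the new PMPO objective (eq (10) in https://arxiv.org/html/2410.04166v1) raises log-probabilities where the advantage is non-negative and lowers them where it is negative, trading the two sides off with α = pmpo_pos_to_neg_weight; advantage magnitudes never enter, only their sign, which makes the loss insensitive to advantage scale. The else branch keeps the standard PPO clipped surrogate. A self-contained sketch of the PMPO branch with hypothetical toy shapes (batch 2, 4 timesteps, 3 actions), writing alpha for the diff's α:

import torch

def masked_mean(t, mask):
    # as in the diff: an empty selection yields a zero-valued sum instead of nan
    return t[mask].mean() if mask.any() else t[mask].sum()

log_probs = torch.randn(2, 4, 3, requires_grad = True)  # (batch, time, num actions)
advantage = torch.randn(2, 4)                           # (batch, time)

# only the sign of the advantage is used
pos_advantage_mask = advantage >= 0.
neg_advantage_mask = ~pos_advantage_mask

alpha = 0.5  # pmpo_pos_to_neg_weight - positive and negative sets weighted equally

# push up log probs on positive-advantage steps, push down on negative ones
pos = masked_mean(log_probs, pos_advantage_mask)
neg = -masked_mean(log_probs, neg_advantage_mask)

policy_loss = -(alpha * pos + (1. - alpha) * neg)
policy_loss.backward()

Note the boolean masks have shape (batch, time) and index the leading dimensions of log_probs, so every action's log-probability at a selected timestep enters the mean; this matches the diff, which builds the masks before broadcasting the advantage to (batch, time, 1) for the PPO branch.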
@@ -2543,6 +2573,7 @@ class DynamicsWorldModel(Module):
         batch_size = 1,
         agent_index = 0,
         tasks: int | Tensor | None = None,
+        latent_gene_ids = None,
         image_height = None,
         image_width = None,
         return_decoded_video = None,
@@ -2658,6 +2689,7 @@ class DynamicsWorldModel(Module):
             step_sizes = step_size,
             rewards = decoded_rewards,
             tasks = tasks,
+            latent_gene_ids = latent_gene_ids,
             discrete_actions = decoded_discrete_actions,
             continuous_actions = decoded_continuous_actions,
             proprio = noised_proprio_with_context,
dreamer4-0.0.91.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=LnCEoVek0ekyttB3z4axdgs_NszqpOg1LnNwhrfSisM,114742
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.91.dist-info/METADATA,sha256=D2RUBJjn7zLpaRH0_duBFaAA0EMw5JuywscwecZUhEc,3065
+dreamer4-0.0.91.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.91.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.91.dist-info/RECORD,,
dreamer4-0.0.89.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=F8UMxI8uAJmMHmXf_Xhibcs_PkMjNm4AW357U941REo,113725
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.89.dist-info/METADATA,sha256=ak7JROeb_PRnOvAg8eoARtnbWh64-2JQJ5zVgSnwKpc,3065
-dreamer4-0.0.89.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.89.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.89.dist-info/RECORD,,
File without changes
|
|
File without changes
|