dreamer4 0.0.92__py3-none-any.whl → 0.0.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +43 -1
- {dreamer4-0.0.92.dist-info → dreamer4-0.0.93.dist-info}/METADATA +1 -1
- dreamer4-0.0.93.dist-info/RECORD +8 -0
- dreamer4-0.0.92.dist-info/RECORD +0 -8
- {dreamer4-0.0.92.dist-info → dreamer4-0.0.93.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.92.dist-info → dreamer4-0.0.93.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -81,6 +81,7 @@ class Experience:
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
+    old_action_unembeds: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
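The new `old_action_unembeds` field follows the same `(discrete, continuous)` tuple convention as `actions` and `log_probs` above. Per the hunks below, the rollout snapshots the policy head's raw distribution parameters so a later PMPO update can penalize divergence from the behavior policy. A minimal sketch of what such a snapshot could hold — all shapes and names here are illustrative assumptions, not the package's actual tensors:

import torch

# hypothetical shapes: (batch, time, num_actions) discrete logits,
# (batch, time, action_dim, 2) continuous mean/std parameters
old_discrete_logits = torch.randn(2, 16, 8)
old_continuous_params = torch.randn(2, 16, 4, 2)

# stored detached, since the behavior policy must stay frozen during the update
old_action_unembeds = (old_discrete_logits.detach(), old_continuous_params.detach())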
@@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_kl_div_loss_weight = 1.,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
         # pmpo related
 
         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
 
         # rewards related
@@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_kv_cache = True,
-        store_agent_embed = False
+        store_agent_embed = False,
+        store_old_action_unembeds = False,
     ):
         assert exists(self.video_tokenizer)
@@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
         latents = None
 
         acc_agent_embed = None
+        acc_policy_embed = None
 
         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
@@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):
 
             policy_embed = self.policy_head(one_agent_embed)
 
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
             # sample actions
 
             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
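The accumulation above depends on `safe_cat` tolerating `None` for the not-yet-initialized accumulator. A plausible sketch of such a helper, assumed rather than taken from the package source:

import torch

def safe_cat(tensors, dim = 0):
    # concatenate, silently dropping None entries; returns None if nothing remains
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    return torch.cat(tensors, dim = dim)

acc_policy_embed = None

for _ in range(3):
    policy_embed = torch.randn(2, 1, 8)                                    # one timestep of embeddings
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1) # grows along time

assert acc_policy_embed.shape == (2, 3, 8)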
@@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
@@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
         old_values = experience.values
         rewards = experience.rewards
         agent_embeds = experience.agent_embed
+        old_action_unembeds = experience.old_action_unembeds
 
         step_size = experience.step_size
         agent_index = experience.agent_index
@@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
+
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
@@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):
 
             policy_loss = -(α * pos + (1. - α) * neg)
 
+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+                # accumulate discrete and continuous kl div
+
+                kl_div_loss = 0.
+
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
         else:
             # ppo clipped surrogate loss
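For orientation: the new branch augments the PMPO objective with a KL penalty that keeps the updated policy close to the rollout-time (behavior) policy, in the spirit of KL-penalized policy gradient methods. A self-contained sketch for a discrete head only — the helper, shapes, mask, and KL direction are all illustrative assumptions, not the internals of `action_embedder.kl_div`:

import torch

def kl_between_logits(old_logits, new_logits):
    # per-position KL(old || new), one plausible direction for this penalty
    old_log_probs = old_logits.log_softmax(dim = -1)
    new_log_probs = new_logits.log_softmax(dim = -1)
    return (old_log_probs.exp() * (old_log_probs - new_log_probs)).sum(dim = -1)

batch, time, num_actions = 2, 16, 8

new_logits = torch.randn(batch, time, num_actions, requires_grad = True)
old_logits = torch.randn(batch, time, num_actions)  # frozen rollout-time snapshot

mask = torch.ones(batch, time, dtype = torch.bool)  # valid (non-padded) timesteps

kl_div_loss = kl_between_logits(old_logits, new_logits)[mask].mean()

pmpo_kl_div_loss_weight = 1.
policy_loss = torch.tensor(0.)                      # stand-in for the pmpo loss above
policy_loss = policy_loss + kl_div_loss * pmpo_kl_div_loss_weight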
@@ -2694,6 +2724,10 @@ class DynamicsWorldModel(Module):
 
         acc_agent_embed = None
 
+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards
 
         decoded_rewards = None
@@ -2818,6 +2852,13 @@ class DynamicsWorldModel(Module):
 
             policy_embed = self.policy_head(one_agent_embed)
 
+            # maybe store old actions
+
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+            # sample actions
+
             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
 
             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2906,6 +2947,7 @@ class DynamicsWorldModel(Module):
             video = video,
             proprio = proprio if has_proprio else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
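Taken together, the intended call pattern appears to be: roll out with the new flag enabled, then learn from the returned `Experience`, at which point the stored unembeds feed the KL term. A hedged usage sketch — the method names and other arguments are guesses at the surrounding API, which this diff does not show:

# hypothetical method names; only the flags and the Experience field come from this diff
experience = world_model.rollout_from_env(
    env,
    max_timesteps = 16,
    store_old_action_unembeds = True   # new: snapshot policy outputs for the kl penalty
)

# experience.old_action_unembeds is a (discrete, continuous) tuple, or None when the flag is off
loss = world_model.learn(experience)   # the kl term applies only when use_pmpo and weight > 0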
dreamer4-0.0.93.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=B2Bk6JJO9MVTWwss9hOP1k6SBiEr56ijNOa3PiidPnY,118120
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.93.dist-info/METADATA,sha256=FhVnlhfeloUPMiFqqJ5qR6fqdd7YmN1-gXykkOTPF_A,3065
+dreamer4-0.0.93.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.93.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.93.dist-info/RECORD,,
dreamer4-0.0.92.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=8RIgxdlgphe19w4bdcy2w7qWqVL4CZtoPfL6LXwj7jE,116380
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.92.dist-info/METADATA,sha256=we70JHu-GvXjlcb5HR2bgtn8dc6L-uOARiNdqAKqDTU,3065
-dreamer4-0.0.92.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.92.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.92.dist-info/RECORD,,
{dreamer4-0.0.92.dist-info → dreamer4-0.0.93.dist-info}/WHEEL
File without changes
{dreamer4-0.0.92.dist-info → dreamer4-0.0.93.dist-info}/licenses/LICENSE
File without changes