dreamer4 0.0.85__tar.gz → 0.0.88__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- {dreamer4-0.0.85 → dreamer4-0.0.88}/PKG-INFO +1 -1
- {dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4/dreamer4.py +28 -19
- {dreamer4-0.0.85 → dreamer4-0.0.88}/pyproject.toml +1 -1
- {dreamer4-0.0.85 → dreamer4-0.0.88}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/.gitignore +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/LICENSE +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/README.md +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.85 → dreamer4-0.0.88}/tests/test_dreamer.py +0 -0
{dreamer4-0.0.85 → dreamer4-0.0.88}/dreamer4/dreamer4.py

@@ -77,7 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
-    agent_embed: Tensor | None = None
+    agent_embed: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
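The hunk above keeps an optional agent_embed tensor on the Experience container so agent embeddings can ride along with the rest of the rollout data. A minimal, self-contained sketch of such a record, using only the fields visible in the hunk (the package's real class contains more than this):

# Illustrative sketch only, not the package's actual class definition.

from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass
class ExperienceSketch:
    latents: Tensor
    video: Tensor | None = None
    proprio: Tensor | None = None
    agent_embed: Tensor | None = None  # cached agent embeddings from the rollout
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None
    log_probs: tuple[Tensor, Tensor] | None = None


# example: an experience whose agent embeddings were stored during rollout
exp = ExperienceSketch(
    latents = torch.randn(1, 8, 16),
    agent_embed = torch.randn(1, 8, 32),
)
assert exp.agent_embed is not None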
@@ -2255,7 +2255,7 @@ class DynamicsWorldModel(Module):
 video = cat((video, next_frame), dim = 2)
 rewards = safe_cat((rewards, reward), dim = 1)

-acc_agent_embed = safe_cat((acc_agent_embed,
+acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)

 # package up one experience for learning

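The rollout code above grows acc_agent_embed with safe_cat. A hedged sketch of a None-tolerant concatenation helper in that spirit follows; the package's own safe_cat may differ in signature and behavior:

import torch
from torch import Tensor


def safe_cat_sketch(tensors: tuple[Tensor | None, ...], dim: int = 1) -> Tensor | None:
    # drop missing pieces so an accumulator can start out as None
    present = [t for t in tensors if t is not None]

    if len(present) == 0:
        return None

    if len(present) == 1:
        return present[0]

    return torch.cat(present, dim = dim)


# usage: growing an accumulator one step at a time
acc = None
for _ in range(3):
    step_embed = torch.randn(2, 1, 8)   # (batch, time = 1, dim)
    acc = safe_cat_sketch((acc, step_embed), dim = 1)

assert acc is not None and acc.shape == (2, 3, 8)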
@@ -2295,6 +2295,7 @@ class DynamicsWorldModel(Module):
 old_log_probs = experience.log_probs
 old_values = experience.values
 rewards = experience.rewards
+agent_embeds = experience.agent_embed

 step_size = experience.step_size
 agent_index = experience.agent_index
@@ -2374,32 +2375,38 @@ class DynamicsWorldModel(Module):
 advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

 # replay for the action logits and values
+# but only do so if fine tuning the entire world model for RL

 discrete_actions, continuous_actions = actions

-
-
-
-
-    step_sizes = step_size,
-    rewards = rewards,
-    discrete_actions = discrete_actions,
-    continuous_actions = continuous_actions,
-    latent_is_noised = True,
-    return_pred_only = True,
-    return_intermediates = True
-)
+if (
+    not only_learn_policy_value_heads or
+    not exists(agent_embeds)
+):

-
+    with world_model_forward_context():
+        _, (agent_embeds, _) = self.forward(
+            latents = latents,
+            signal_levels = self.max_steps - 1,
+            step_sizes = step_size,
+            rewards = rewards,
+            discrete_actions = discrete_actions,
+            continuous_actions = continuous_actions,
+            latent_is_noised = True,
+            return_pred_only = True,
+            return_intermediates = True
+        )
+
+    agent_embeds = agent_embeds[..., agent_index, :]

 # maybe detach agent embed

 if only_learn_policy_value_heads:
-
+    agent_embeds = agent_embeds.detach()

 # ppo

-policy_embed = self.policy_head(
+policy_embed = self.policy_head(agent_embeds)

 log_probs, entropies = self.action_embedder.log_probs(policy_embed, pred_head_index = 0, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)

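The hunk above is the core of this release: the PPO update replays the world model only when the whole model is being fine tuned for RL or no cached agent embeddings exist, and detaches the embeddings when only the policy and value heads are trained. The sketch below shows that control flow in isolation; the module and argument names (world_model, get_agent_embeds, the tensor shapes) are placeholders, not the package's API:

import torch
from torch import Tensor, nn


def get_agent_embeds(
    world_model: nn.Module,          # placeholder for the dynamics world model
    latents: Tensor,
    cached_agent_embeds: Tensor | None,
    only_learn_policy_value_heads: bool
) -> Tensor:

    if only_learn_policy_value_heads and cached_agent_embeds is not None:
        # heads-only fine tuning: reuse the embeddings stored during rollout
        agent_embeds = cached_agent_embeds
    else:
        # full fine tuning (or no cache): replay the world model
        agent_embeds = world_model(latents)

    if only_learn_policy_value_heads:
        # keep gradients inside the heads; do not backprop into the world model
        agent_embeds = agent_embeds.detach()

    return agent_embeds


# usage with a stand-in world model
world_model = nn.Linear(16, 32)
latents = torch.randn(2, 8, 16)
cached = torch.randn(2, 8, 32)

embeds = get_agent_embeds(world_model, latents, cached, only_learn_policy_value_heads = True)
assert not embeds.requires_grad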
@@ -2448,7 +2455,7 @@ class DynamicsWorldModel(Module):

 # value loss

-value_bins = self.value_head(
+value_bins = self.value_head(agent_embeds)
 values = self.reward_encoder.bins_to_scalar_value(value_bins)

 clipped_values = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
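The value branch above clips the new value predictions around the old ones before computing the loss, as in PPO. Below is a minimal sketch of the clipped value loss that such a line typically feeds into; the exact loss combination and clip value in the package may differ:

import torch
from torch import Tensor
import torch.nn.functional as F


def clipped_value_loss(
    values: Tensor,        # new value predictions
    old_values: Tensor,    # value predictions stored at rollout time
    returns: Tensor,       # bootstrapped return targets
    value_clip: float = 0.4
) -> Tensor:
    clipped_values = old_values + (values - old_values).clamp(-value_clip, value_clip)

    loss_unclipped = F.mse_loss(values, returns, reduction = 'none')
    loss_clipped = F.mse_loss(clipped_values, returns, reduction = 'none')

    # take the pessimistic (larger) of the two errors, as in PPO
    return torch.maximum(loss_unclipped, loss_clipped).mean()


# usage
values, old_values, returns = torch.randn(3, 4, 8).unbind(dim = 0)
loss = clipped_value_loss(values, old_values, returns)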
@@ -2665,7 +2672,9 @@ class DynamicsWorldModel(Module):

 # maybe store agent embed

-
+if store_agent_embed:
+    one_agent_embed = agent_embed[:, -1:, agent_index]
+    acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)

 # decode the agent actions if needed
